colossalai.utils.activation_checkpoint

colossalai.utils.activation_checkpoint.checkpoint(function, activation_offload, *args)

Checkpoints the computation while preserving the RNG states. Modified from PyTorch's torch.utils.checkpoint.

Parameters
  • function – The function that performs the forward pass. It should know how to handle the input tuple.

  • activation_offload – Whether to offload the activations to CPU.

  • args – Tuple containing the positional arguments passed to function.

Returns

The output of running function with the provided args.
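The RNG-preservation idea in the description can be illustrated without PyTorch or ColossalAI. The sketch below is a toy model under stated assumptions: Python's random module stands in for the framework RNG, a hypothetical ToyCheckpoint class stands in for the real autograd function, and its recompute() method plays the role of the backward-pass recomputation. It is not ColossalAI's implementation, and it does no CPU offloading.

```python
import random
from typing import Any, Callable, List, Tuple


class ToyCheckpoint:
    """Hypothetical, pure-Python sketch of activation checkpointing with RNG preservation.

    forward() keeps only the inputs and the RNG state captured before the call;
    intermediate results are not stored. recompute() restores that RNG state and
    reruns the function, so random ops (e.g. dropout masks) are reproduced exactly.
    This is an illustration only, not ColossalAI's or PyTorch's implementation.
    """

    def __init__(self, function: Callable[..., Any]) -> None:
        self.function = function
        self.saved_args: Tuple[Any, ...] = ()
        self.saved_rng_state: Any = None

    def forward(self, *args: Any) -> Any:
        # Save the RNG state and the inputs; everything else is recomputed later.
        self.saved_rng_state = random.getstate()
        self.saved_args = args
        return self.function(*args)

    def recompute(self) -> Any:
        # Temporarily restore the forward-time RNG state, rerun the function,
        # then put the caller's current RNG state back so its stream is undisturbed.
        current_state = random.getstate()
        random.setstate(self.saved_rng_state)
        try:
            return self.function(*self.saved_args)
        finally:
            random.setstate(current_state)


def toy_dropout(xs: List[float], p: float = 0.5) -> List[float]:
    # Hypothetical dropout: zero each element with probability p, rescale survivors by 1 / (1 - p).
    return [0.0 if random.random() < p else x / (1 - p) for x in xs]


if __name__ == "__main__":
    ckpt = ToyCheckpoint(toy_dropout)
    out = ckpt.forward([1.0, 2.0, 3.0, 4.0])
    random.random()  # other code advances the RNG between forward and "backward"
    print(ckpt.recompute() == out)  # the same dropout mask is reproduced
```

Restoring the saved state before rerunning is what makes random operations such as dropout produce the same mask in both passes; without it, the recomputed activations, and therefore the gradients, would not match the original forward pass. PyTorch's torch.utils.checkpoint does the analogous thing by recomputing inside torch.random.fork_rng.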