colossalai.utils.activation_checkpoint
- colossalai.utils.activation_checkpoint.checkpoint(function, activation_offload, *args)
Checkpoint the computation while preserving the RNG states; modified from PyTorch's torch.utils.checkpoint.
- Parameters
function – The forward pass function. It should know how to handle the input tuples.
activation_offload – Whether to offload activations to CPU memory.
args – Tuple containing the arguments passed to function.
- Returns
Output of running function with the provided args.
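The recompute-with-restored-RNG idea behind checkpoint can be sketched in plain Python. This is a simplified illustration, not the Colossal-AI implementation: the real function integrates with PyTorch autograd and (like torch.utils.checkpoint) saves framework RNG states such as the CPU and CUDA generators, while this sketch only uses Python's global random module. The names noisy_layer and checkpoint_sketch are hypothetical, and activation_offload is not modeled.

```python
import random
from typing import Any, Callable, List, Tuple


def noisy_layer(xs: List[float]) -> List[float]:
    """Hypothetical toy forward pass whose output depends on the global RNG, like dropout."""
    return [x * 2.0 if random.random() >= 0.5 else 0.0 for x in xs]


def checkpoint_sketch(function: Callable[..., Any], *args: Any) -> Tuple[Any, Callable[[], Any]]:
    """Hypothetical sketch: run `function` once, keeping only its inputs and the RNG state.

    Returns the output plus a `recompute` closure that replays the forward pass
    under the original RNG state, standing in for what a backward pass would do
    instead of reading stored intermediate activations.
    """
    rng_state = random.getstate()  # snapshot the RNG before the forward pass
    output = function(*args)       # intermediate values are not kept

    def recompute() -> Any:
        current = random.getstate()  # remember the caller's RNG position
        random.setstate(rng_state)   # replay with the same randomness as the forward pass
        try:
            return function(*args)
        finally:
            random.setstate(current)  # leave the caller's RNG stream untouched

    return output, recompute


random.seed(0)
out, recompute = checkpoint_sketch(noisy_layer, [1.0, 2.0, 3.0, 4.0])
random.random()  # unrelated RNG use between the forward and the "backward" pass
assert recompute() == out  # the recomputed "dropout mask" matches the original
```

Restoring the caller's RNG state after recomputation matters: without it, replaying the forward pass would silently shift every random draw that follows.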