colossalai.amp.apex_amp.apex_amp
- class colossalai.amp.apex_amp.apex_amp.ApexAMPOptimizer(optim)
A wrapper class for an APEX optimizer that implements apex-specific backward and clip_grad_norm methods.
- backward(loss)
Runs the backward pass to compute all gradients.
- Parameters
loss (torch.Tensor) – Loss computed by a loss function
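In apex, a loss-scaled backward pass is normally written as "with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward()". It is an assumption here that backward(loss) wraps that pattern. The sketch below illustrates the underlying technique, dynamic loss scaling, in pure Python on a hypothetical one-parameter toy model so that it runs without apex or torch. The names toy_grad, DynamicLossScaler and FP16_MAX and the default values are illustrative only. They are not apex's API, and they are not what apex uses internally.

```python
import math
from typing import Optional

# Simplification (assumption): treat any gradient above the fp16 maximum as an
# overflow, standing in for real half-precision arithmetic.
FP16_MAX = 65504.0


def toy_grad(w: float, x: float, y: float, scale: float = 1.0) -> float:
    """Hypothetical toy model: gradient of scale * (w * x - y) ** 2 w.r.t. w."""
    return scale * 2.0 * (w * x - y) * x


class DynamicLossScaler:
    """Sketch of dynamic loss scaling; not apex's actual implementation."""

    def __init__(self, init_scale: float = 2.0 ** 16, backoff: float = 0.5,
                 growth: float = 2.0, growth_interval: int = 2000) -> None:
        self.scale = init_scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self._good_steps = 0

    def backward(self, w: float, x: float, y: float) -> Optional[float]:
        # Backward pass on the scaled loss, so small gradients stay representable.
        scaled = toy_grad(w, x, y, self.scale)
        if not math.isfinite(scaled) or abs(scaled) > FP16_MAX:
            # Overflow: shrink the scale and signal the caller to skip this step.
            self.scale *= self.backoff
            self._good_steps = 0
            return None
        # Unscale with the same factor that was used for the backward pass.
        grad = scaled / self.scale
        self._good_steps += 1
        if self._good_steps % self.growth_interval == 0:
            # A long run of finite gradients: try a larger scale.
            self.scale *= self.growth
        return grad
```

For example, DynamicLossScaler(init_scale=1024.0).backward(1.0, 2.0, 1.0) returns the true gradient 4.0. A call that overflows returns None and halves the scale, so an optimizer step for that iteration would be skipped.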
- clip_grad_norm(model, max_norm)
Clips the norm of the gradients.
- Parameters
model (torch.nn.Module) – Your model object
max_norm (float) – The max norm value for gradient clipping
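Gradient-norm clipping is commonly done by the global L2 norm, the same rule used by torch.nn.utils.clip_grad_norm_. If the total norm over all gradients exceeds max_norm, every gradient is multiplied by max_norm / (total_norm + eps). The pure-Python sketch below shows that rule on plain lists of floats, which stand in hypothetically for per-parameter gradient tensors. It is an assumption that the apex wrapper applies this rule to apex's FP32 master parameters (for example via apex.amp.master_params(optim)) rather than directly to model.parameters().

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float,
                   eps: float = 1e-6) -> float:
    """Sketch: clip gradients in place by their global L2 norm.

    Returns the total norm measured before clipping.
    """
    # Global L2 norm over every element of every gradient.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Scale all gradients by the same factor, preserving their direction.
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm
```

Because every gradient is scaled by the same coefficient, the update direction is unchanged and only its magnitude is capped. Gradients whose total norm is already below max_norm are left untouched.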