colossalai.amp.apex_amp.apex_amp

class colossalai.amp.apex_amp.apex_amp.ApexAMPOptimizer(optim)

A wrapper class for an Apex optimizer that implements Apex-specific backward and clip_grad_norm methods.

backward(loss)

Runs the backward pass to compute all gradients.

Parameters

loss (torch.Tensor) – Loss computed by a loss function
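
Apex AMP usually scales the loss before the backward pass so that small gradients are not lost in half precision. The sketch below illustrates only that idea, using the standard library; it does not call Apex or ColossalAI. The helper to_fp16, the gradient value, and the scale factor are all assumptions made for this example.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Hypothetical values chosen for illustration only.
true_grad = 1e-8         # a gradient too small to represent in fp16
loss_scale = 2.0 ** 16   # assumed scale factor, not read from any library

# Without scaling, the gradient underflows to zero in half precision.
unscaled = to_fp16(true_grad)

# Multiplying the loss by loss_scale multiplies every gradient by the same
# factor (the chain rule is linear), so the scaled gradient survives in fp16.
scaled = to_fp16(true_grad * loss_scale)

# Dividing by the scale in higher precision recovers the original gradient
# up to fp16 rounding error.
recovered = scaled / loss_scale

print(unscaled, recovered)
```

In real training the scaling and unscaling are handled by the AMP library, so user code only passes the unscaled loss to backward(loss).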

clip_grad_norm(model, max_norm)

Clips the norm of the gradients.

Parameters
  • model (torch.nn.Module) – Your model object

  • max_norm (float) – The maximum norm allowed for the gradients
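
The arithmetic behind clipping by norm can be sketched in plain Python. This is a minimal illustration of clipping by the global L2 norm, not the ColossalAI or Apex implementation; the function name clip_by_global_norm, the eps term, and the nested-list representation of gradients are assumptions made for this example.

```python
import math
from typing import List, Tuple


def clip_by_global_norm(
    grads: List[List[float]], max_norm: float, eps: float = 1e-6
) -> Tuple[List[List[float]], float]:
    """Rescale all gradients so their combined L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    # Global norm: treat every gradient entry as part of one long vector.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))

    # Scaling coefficient; eps guards against division by zero.
    coef = max_norm / (total_norm + eps)

    if coef < 1.0:
        # The norm exceeds max_norm: shrink every gradient by the same factor,
        # which preserves the direction of the overall update.
        clipped = [[g * coef for g in grad] for grad in grads]
    else:
        # The norm is already within bounds: return an unchanged copy.
        clipped = [list(grad) for grad in grads]
    return clipped, total_norm


# Two hypothetical parameter gradients with a combined norm of 5.
grads = [[3.0], [4.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in clipped for g in grad))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same coefficient, the relative sizes of the per-parameter gradients are unchanged; only the overall step length is limited.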