colossalai.amp.torch_amp.torch_amp

class colossalai.amp.torch_amp.torch_amp.TorchAMPOptimizer(optim, *args, **kwargs)

A wrapper class that integrates PyTorch AMP with an optimizer

Parameters
  • optim (torch.optim.Optimizer) – A standard optimizer such as Adam or SGD

  • args – Positional arguments used to initialize the gradient scaler

  • kwargs – Keyword arguments used to initialize the gradient scaler
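A minimal sketch of how such a wrapper can hold its scaler, assuming the extra `*args`/`**kwargs` are forwarded to `torch.cuda.amp.GradScaler`. The class name `SketchAMPOptimizer` is hypothetical and this is not the ColossalAI source; scaling is only enabled when CUDA is available so the sketch also runs on CPU.

```python
import torch
import torch.nn as nn


class SketchAMPOptimizer:
    """Hypothetical stand-in for TorchAMPOptimizer (not ColossalAI's code)."""

    def __init__(self, optim, *args, **kwargs):
        self.optim = optim
        # Assumption: the extra arguments configure a torch GradScaler
        # (e.g. init_scale, growth_interval). Scaling is turned off on
        # CPU-only machines unless the caller sets `enabled` explicitly.
        kwargs.setdefault("enabled", torch.cuda.is_available())
        self.scaler = torch.cuda.amp.GradScaler(*args, **kwargs)


model = nn.Linear(4, 2)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)
amp_optim = SketchAMPOptimizer(sgd, init_scale=2.0 ** 10, growth_interval=1000)
print(amp_optim.scaler.is_enabled())
```

Forwarding the scaler arguments rather than hard-coding them lets callers tune the initial scale and growth schedule without touching the wrapper.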

backward(loss)

Run the backward pass using the torch AMP gradient scaler

Parameters

loss (torch.Tensor) – Loss computed by a loss function
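A scaled backward pass in plain PyTorch looks roughly like the following, assuming the wrapper's scaler is a `torch.cuda.amp.GradScaler`. This is an illustration of the technique, not the wrapper's actual code; on a CPU-only machine the scaler is disabled and `scale()` returns the loss unchanged.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(4, 1).to(device)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(8, 4, device=device)
loss = model(x).pow(2).mean()

# Multiply the loss by the current scale factor before backprop so that
# small fp16 gradients do not underflow to zero. With scaling disabled
# this is an ordinary loss.backward().
scaler.scale(loss).backward()
print(model.weight.grad is not None)
```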

step()

Update the parameters of the model
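The update step with a torch gradient scaler typically pairs `GradScaler.step` with `GradScaler.update`. The sketch below assumes the wrapper delegates to those two calls; it is not taken from the ColossalAI source.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(4, 1).to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

before = model.weight.detach().clone()
loss = model(torch.randn(8, 4, device=device)).pow(2).mean()
scaler.scale(loss).backward()

# step() unscales the gradients and skips optim.step() if any of them are
# inf/NaN; update() then grows or shrinks the scale for the next iteration.
scaler.step(optim)
scaler.update()
optim.zero_grad()
```

Skipping the update on overflow, instead of applying corrupted gradients, is what makes aggressive loss scaling safe.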

clip_grad_norm(model, max_norm)

Apply gradient clipping to the model parameters

Parameters
  • model (torch.nn.Module) – The model whose parameter gradients are clipped

  • max_norm (float) – Maximum gradient norm allowed after clipping
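When gradients are scaled, they must be unscaled before clipping; otherwise `max_norm` is compared against the scaled norm. A sketch of that ordering with standard PyTorch calls follows (how `clip_grad_norm` is implemented internally is an assumption here). The loss is multiplied by 100 only so that clipping actually triggers.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(4, 1).to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

loss = model(torch.randn(8, 4, device=device)).pow(2).mean() * 100
scaler.scale(loss).backward()

max_norm = 1.0
# Gradients still carry the scale factor; unscale them in place first so
# max_norm is compared against the true gradient norm.
scaler.unscale_(optim)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# Total L2 norm over all parameter gradients after clipping.
total_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))

scaler.step(optim)  # already unscaled above, so step() does not unscale again
scaler.update()
```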

class colossalai.amp.torch_amp.torch_amp.TorchAMPModel(model)

A wrapper class for a model object that runs the forward pass with values automatically cast to fp16
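A model wrapper of this kind can be sketched by running the wrapped module's forward inside `torch.autocast`. The class `SketchAMPModel` is hypothetical; autocast is enabled only on CUDA, and `bfloat16` is passed on CPU solely because it is the dtype CPU autocast accepts (it has no effect here since autocast is disabled there).

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device_type = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16


class SketchAMPModel(nn.Module):
    """Hypothetical stand-in for TorchAMPModel: runs forward under autocast."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        # On CUDA, eligible ops such as the matmul in nn.Linear run in fp16.
        # Autocast is disabled on CPU-only machines, so outputs stay fp32.
        with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=use_cuda):
            return self.model(*args, **kwargs)


wrapped = SketchAMPModel(nn.Linear(4, 2).to(device_type))
out = wrapped(torch.randn(3, 4, device=device_type))
print(out.dtype)  # torch.float16 on CUDA, torch.float32 on CPU
```

Only the forward computation is wrapped; the parameters themselves remain fp32, which is why a gradient scaler is still needed on the optimizer side.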

class colossalai.amp.torch_amp.torch_amp.TorchAMPLoss(loss)

A wrapper class for a criterion object that computes the loss in a mixed-precision context

Parameters

loss (torch.nn.modules.loss._Loss) – A loss function object
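The criterion wrapper can be sketched the same way, assuming it simply evaluates the wrapped loss inside `torch.autocast`. `SketchAMPLoss` is a hypothetical name, and the CPU fallback mirrors the model sketch (autocast disabled, `bfloat16` passed only as an accepted placeholder dtype).

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device_type = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16


class SketchAMPLoss(nn.Module):
    """Hypothetical stand-in for TorchAMPLoss: computes the loss under autocast."""

    def __init__(self, loss):
        super().__init__()
        self.loss = loss

    def forward(self, *args, **kwargs):
        # Numerically sensitive ops like cross_entropy are kept in fp32 by
        # autocast, so the loss value itself stays full precision.
        with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=use_cuda):
            return self.loss(*args, **kwargs)


criterion = SketchAMPLoss(nn.CrossEntropyLoss())
logits = torch.randn(5, 3, device=device_type)
targets = torch.randint(0, 3, (5,), device=device_type)
loss = criterion(logits, targets)
print(loss.dtype)
```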