colossalai.amp.torch_amp

colossalai.amp.torch_amp.convert_to_torch_amp(model, optimizer, criterion=None, amp_config=None)

A helper function that wraps the training components (model, optimizer and, optionally, the loss function) with Torch AMP modules.

Parameters
  • model (torch.nn.Module) – your model object

  • optimizer (torch.optim.Optimizer) – your optimizer object

  • criterion (torch.nn.modules.loss._Loss, optional) – your loss function object

  • amp_config (colossalai.context.Config or dict, optional) – configuration for different amp modes

Returns

(model, optimizer, criterion)

Return type

Tuple
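The returned objects are meant to be used together in the training loop. The following is a minimal sketch of one iteration that relies only on the wrapper methods documented in this section (`backward`, `clip_grad_norm`, `step`). The helper name `train_step`, the `max_norm` default and the `zero_grad()` call on the wrapped optimizer are assumptions made for illustration, not part of the documented API.

```python
def train_step(model, optimizer, criterion, inputs, targets, max_norm=1.0):
    """Run one mixed-precision iteration with already-wrapped components.

    Assumes the three objects were obtained beforehand, e.g.
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion)
    """
    optimizer.zero_grad()                      # assumption: the wrapper exposes zero_grad()
    outputs = model(inputs)                    # wrapped model: forward under automatic casting
    loss = criterion(outputs, targets)         # wrapped criterion: loss in mixed-precision context
    optimizer.backward(loss)                   # used in place of loss.backward()
    optimizer.clip_grad_norm(model, max_norm)  # optional gradient clipping
    optimizer.step()                           # parameter update
    return loss
```

The function is duck-typed on purpose: it only needs objects that provide these methods, so the call order can be checked with simple stand-ins before running on real hardware.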

class colossalai.amp.torch_amp.TorchAMPModel(model)

A wrapper class for a model object that runs the forward pass with values automatically cast to fp16
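To give a feel for what casting to fp16 does to individual values, the sketch below emulates half-precision rounding with the standard library's `struct` module. It is only an illustration of the number format, not of how the wrapper or PyTorch autocast is implemented, and the helper name `to_fp16` is hypothetical.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

approx_tenth = to_fp16(0.1)    # close to 0.1, but not equal: only 10 mantissa bits
big_integer = to_fp16(2049.0)  # above 2048, consecutive fp16 values are 2 apart
tiny_value = to_fp16(1e-8)     # below the smallest fp16 subnormal, becomes 0.0
```

The last case is the one that matters most for training: very small values vanish entirely in fp16.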

class colossalai.amp.torch_amp.TorchAMPLoss(loss)

A wrapper class for a criterion object that computes the loss in a mixed-precision context

Parameters

loss (torch.nn.modules.loss._Loss) – A loss function object

class colossalai.amp.torch_amp.TorchAMPOptimizer(optim, *args, **kwargs)

A wrapper class that integrates PyTorch AMP with an optimizer

Parameters
  • optim (torch.optim.Optimizer) – A normal optimizer like Adam or SGD

  • args – Positional arguments used to initialize the gradient scaler

  • kwargs – Keyword arguments used to initialize the gradient scaler

backward(loss)

Runs the backward pass through the Torch AMP gradient scaler

Parameters

loss (torch.Tensor) – Loss computed by a loss function
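The reason the backward pass goes through a gradient scaler can be shown with plain arithmetic: a gradient that is too small for fp16 underflows to zero, whereas the same gradient computed from a scaled loss stays representable and can be divided back afterwards. The sketch below emulates fp16 with the `struct` module. The gradient value and the helper `to_fp16` are hypothetical, and 65536 is used because it is the default `init_scale` of `torch.cuda.amp.GradScaler`.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 1e-8   # hypothetical small gradient value
scale = 65536.0    # default init_scale of torch.cuda.amp.GradScaler

# Without scaling, the gradient is lost when stored in fp16.
unscaled = to_fp16(true_grad)

# With scaling, the fp16 value is non-zero; dividing by the scale in
# full precision recovers the gradient up to fp16 rounding error.
recovered = to_fp16(true_grad * scale) / scale
```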

step()

Update the parameters of the model
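When a gradient scaler is involved, an update step usually does more than apply the gradients: it unscales them, skips the update if any of them overflowed, and adjusts the scale. The toy class below sketches that logic in pure Python. It is modelled on the documented behaviour of `torch.cuda.amp.GradScaler`, whose keyword names and defaults (`init_scale`, `growth_factor`, `backoff_factor`, `growth_interval`) it reuses. `ToyGradScaler`, the plain-list parameters and the SGD update are assumptions for illustration; this is not the wrapper's actual implementation, and PyTorch splits the same logic across separate `step` and `update` calls.

```python
import math

class ToyGradScaler:
    """Pure-Python sketch of dynamic loss scaling with step skipping."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without overflow

    def step(self, params, scaled_grads, lr):
        """Apply an SGD update in place; return False if the step was skipped."""
        grads = [g / self.scale for g in scaled_grads]  # unscale first
        if not all(math.isfinite(g) for g in grads):
            # Overflow detected: leave params untouched and shrink the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        for i, g in enumerate(grads):
            params[i] -= lr * g
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True
```

A skipped step is not an error. It simply means the current scale was too large, and the next iteration retries with a smaller one.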

clip_grad_norm(model, max_norm)

Apply gradient clipping to the model parameters

Parameters
  • model (torch.nn.Module) – Your model object

  • max_norm (float) – Max norm value for gradient clipping
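Norm-based clipping rescales all gradients by a common factor whenever their global L2 norm exceeds `max_norm`. Under AMP the gradients must be unscaled before the norm is computed, otherwise the threshold would be compared against scaled values. The pure-Python sketch below illustrates the arithmetic. The helper name `clip_by_global_norm` and the flat list of gradients are hypothetical, and the `1e-6` epsilon mirrors the one used by `torch.nn.utils.clip_grad_norm_`.

```python
import math

def clip_by_global_norm(scaled_grads, max_norm, scale=1.0, eps=1e-6):
    """Unscale gradients, then clip them to a global L2 norm of max_norm.

    Returns the clipped gradients and the norm measured before clipping.
    """
    grads = [g / scale for g in scaled_grads]        # undo loss scaling
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))   # never scale gradients up
    return [g * coef for g in grads], total_norm
```

For example, scaled gradients `[6.0, 8.0]` with `scale=2.0` have an unscaled norm of 5, so with `max_norm=1.0` they are shrunk to roughly `[0.6, 0.8]`.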