colossalai.amp.torch_amp

colossalai.amp.torch_amp.convert_to_torch_amp(model, optimizer, criterion=None, amp_config=None)[source]

A helper function that wraps training components with PyTorch AMP modules.

Parameters
  • model (torch.nn.Module) – your model object.

  • optimizer (torch.optim.Optimizer) – your optimizer object

  • criterion (torch.nn.modules.loss._Loss, optional) – your loss function object

  • amp_config (colossalai.context.Config or dict, optional) – configuration for PyTorch AMP.

The amp_config may include the parameters below, all of which are optional:

  • init_scale (float, optional, default=2.**16)
  • growth_factor (float, optional, default=2.0)
  • backoff_factor (float, optional, default=0.5)
  • growth_interval (int, optional, default=2000)
  • enabled (bool, optional, default=True)

Returns

A tuple (model, optimizer, criterion)
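
As a minimal sketch of how the configuration keys listed above might be passed, the snippet below collects the documented defaults in a plain dict and forwards them to convert_to_torch_amp. The names AMP_CONFIG and wrap_with_torch_amp are hypothetical and exist only for illustration. The import is done lazily inside the function so the module still loads where colossalai is not installed.

```python
# Keys and default values copied from the amp_config parameter list.
AMP_CONFIG = {
    "init_scale": 2.0 ** 16,
    "growth_factor": 2.0,
    "backoff_factor": 0.5,
    "growth_interval": 2000,
    "enabled": True,
}


def wrap_with_torch_amp(model, optimizer, criterion=None, amp_config=None):
    """Hypothetical helper: merge overrides into the defaults and wrap the components."""
    # Lazy import, so defining this function does not require colossalai.
    from colossalai.amp.torch_amp import convert_to_torch_amp

    config = dict(AMP_CONFIG)        # copy, so the defaults stay untouched
    config.update(amp_config or {})  # apply caller overrides, if any
    # Returns the tuple (model, optimizer, criterion) described above.
    return convert_to_torch_amp(model, optimizer, criterion, config)
```

A caller would typically unpack the result as model, optimizer, criterion = wrap_with_torch_amp(model, optimizer, criterion).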

class colossalai.amp.torch_amp.TorchAMPModel(model)[source]

A wrapper class for a model object that runs the forward pass with values automatically cast to fp16.

class colossalai.amp.torch_amp.TorchAMPLoss(loss)[source]

A wrapper class for a criterion object which computes the loss in mixed-precision context

Parameters

loss (torch.nn.modules.loss._Loss) – A loss function object

class colossalai.amp.torch_amp.TorchAMPOptimizer(optim, *args, **kwargs)[source]

A wrapper class that integrates PyTorch AMP with an optimizer.

Parameters
  • optim (torch.optim.Optimizer) – A normal optimizer like Adam or SGD.

  • init_scale (float, optional, default=2.**16) – Initial scale factor.

  • growth_factor (float, optional, default=2.0) – Factor by which the scale is multiplied during update() if no inf/NaN gradients occur for growth_interval consecutive iterations.

  • backoff_factor (float, optional, default=0.5) – Factor by which the scale is multiplied during update() if inf/NaN gradients occur in an iteration.

  • growth_interval (int, optional, default=2000) – Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by growth_factor.

  • enabled (bool, optional, default=True) – If False, disables gradient scaling. step() simply invokes the underlying optimizer.step(), and other methods become no-ops.
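
To illustrate how init_scale, growth_factor, backoff_factor and growth_interval interact, here is a pure-Python sketch of the scale update rule as the parameter descriptions above state it. It is a simplified model for intuition only, not the library's implementation. The function name simulate_scale and its found_inf_per_step argument are hypothetical.

```python
def simulate_scale(found_inf_per_step, init_scale=2.0 ** 16, growth_factor=2.0,
                   backoff_factor=0.5, growth_interval=2000):
    """Return the loss scale after each iteration.

    found_inf_per_step: iterable of bools, True when that iteration
    produced inf/NaN gradients.
    """
    scale = init_scale
    clean_steps = 0  # consecutive iterations without inf/NaN gradients
    history = []
    for found_inf in found_inf_per_step:
        if found_inf:
            # Overflow: shrink the scale and restart the clean-step count.
            scale *= backoff_factor
            clean_steps = 0
        else:
            clean_steps += 1
            if clean_steps == growth_interval:
                # Enough clean iterations in a row: grow the scale.
                scale *= growth_factor
                clean_steps = 0
        history.append(scale)
    return history
```

With a small growth_interval the behaviour is easy to follow by hand: two clean iterations double the scale, and a single overflow halves it again.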

backward(loss)[source]

Run the backward pass on the loss using the PyTorch AMP gradient scaler.

Parameters

loss (torch.Tensor) – Loss computed by a loss function

step()[source]

Update the parameters of the model

clip_grad_norm(model, max_norm)[source]

Apply gradient clipping to the model parameters

Parameters
  • model (torch.nn.Module) – Your model object

  • max_norm (float) – Max norm value for gradient clipping
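
The methods above are typically combined in a single training iteration. The sketch below shows one plausible ordering, assuming model, optimizer and criterion are the wrapped objects returned by convert_to_torch_amp. The function train_step is hypothetical, and the call to zero_grad() assumes the wrapper exposes the underlying optimizer's method of that name.

```python
def train_step(model, optimizer, criterion, data, label, max_norm=1.0):
    """Run one training iteration with the AMP-wrapped components."""
    optimizer.zero_grad()            # assumption: forwarded to the wrapped optimizer
    output = model(data)             # forward pass through the wrapped model
    loss = criterion(output, label)  # loss computed by the wrapped criterion
    optimizer.backward(loss)         # backward pass via the gradient scaler
    # Clip after backward() and before step(), so the gradients exist
    # and have not yet been consumed by the parameter update.
    optimizer.clip_grad_norm(model, max_norm)
    optimizer.step()                 # update the model parameters
    return loss
```

Clipping is placed between backward() and step() because the gradients only exist after the backward pass and must be clipped before the parameters are updated.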