colossalai.amp.torch_amp
- colossalai.amp.torch_amp.convert_to_torch_amp(model, optimizer, criterion=None, amp_config=None)
A helper function that wraps training components with PyTorch AMP modules.
- Parameters
model (torch.nn.Module) – your model object.
optimizer (torch.optim.Optimizer) – your optimizer object.
criterion (torch.nn.modules.loss._Loss, optional) – your loss function object.
amp_config (colossalai.context.Config or dict, optional) – configuration for PyTorch AMP.
The amp_config may include the following parameters:
init_scale (float, optional, default=2.**16)
growth_factor (float, optional, default=2.0)
backoff_factor (float, optional, default=0.5)
growth_interval (int, optional, default=2000)
enabled (bool, optional, default=True)
- Returns
A tuple (model, optimizer, criterion)
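To make the amp_config parameters concrete, here is a minimal stdlib-only sketch that writes them down as a plain dict with the documented defaults and overlays a user config on top. It assumes amp_config behaves like a dict (colossalai.context.Config is treated as dict-like here, which is an assumption). resolve_amp_config and TORCH_AMP_DEFAULTS are hypothetical names for illustration; they are not part of ColossalAI, and the real convert_to_torch_amp may process the config differently.

```python
from typing import Any, Dict, Optional

# Defaults as documented for the amp_config parameters above.
TORCH_AMP_DEFAULTS: Dict[str, Any] = {
    "init_scale": 2.0 ** 16,
    "growth_factor": 2.0,
    "backoff_factor": 0.5,
    "growth_interval": 2000,
    "enabled": True,
}


def resolve_amp_config(amp_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: overlay a user-supplied amp_config on the defaults.

    Unknown keys raise KeyError so that typos are not silently ignored
    (a design choice of this sketch, not documented ColossalAI behavior).
    """
    resolved = dict(TORCH_AMP_DEFAULTS)  # copy, so the defaults are never mutated
    if not amp_config:
        return resolved
    unknown = set(amp_config) - set(TORCH_AMP_DEFAULTS)
    if unknown:
        raise KeyError(f"unsupported amp_config keys: {sorted(unknown)}")
    resolved.update(amp_config)
    return resolved


# Example: start with a smaller scale and grow it more often than the default.
amp_config = {"init_scale": 2.0 ** 10, "growth_interval": 1000}
resolved = resolve_amp_config(amp_config)
print(resolved["init_scale"], resolved["growth_interval"], resolved["backoff_factor"])
# 1024.0 1000 0.5
```

In a real program, a dict like amp_config above would be passed through the documented signature, e.g. convert_to_torch_amp(model, optimizer, criterion, amp_config=amp_config).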
- class colossalai.amp.torch_amp.TorchAMPModel(model)
A wrapper class for a model object that executes the forward pass with values automatically cast to fp16.
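Because the wrapped forward pass produces fp16 values, it helps to see what that format can and cannot represent. The stdlib-only sketch below round-trips Python floats through IEEE 754 half precision using the struct module's "e" format. It only illustrates the numeric format; PyTorch AMP decides per operation which ops run in fp16, and this sketch does not reproduce that. to_fp16 is a hypothetical helper name.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value and back.

    struct's 'e' format packs binary16; finite values outside the fp16 range
    raise OverflowError, which this sketch maps to a signed infinity.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


print(to_fp16(1.0))      # 1.0 (exact)
print(to_fp16(0.1))      # 0.0999755859375 (only ~3 decimal digits of precision)
print(to_fp16(65504.0))  # 65504.0, the largest finite fp16 value
print(to_fp16(70000.0))  # inf (overflow)
print(to_fp16(1e-8))     # 0.0 (underflow below the smallest subnormal)

# Multiplying by a large factor first (as a loss scaler does) keeps a tiny
# value representable; dividing afterwards in full precision recovers ~1e-8.
recovered = to_fp16(1e-8 * 2.0 ** 16) / 2.0 ** 16
```

The last line hints at why the optimizer wrapper below carries a scale factor: small gradients that would flush to zero in fp16 survive when the loss is scaled up first.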
- class colossalai.amp.torch_amp.TorchAMPLoss(loss)
A wrapper class for a criterion object that computes the loss in a mixed-precision context.
- Parameters
loss (torch.nn.modules.loss._Loss) – A loss function object
- class colossalai.amp.torch_amp.TorchAMPOptimizer(optim, *args, **kwargs)
A wrapper class that integrates PyTorch AMP with an optimizer.
- Parameters
optim (torch.optim.Optimizer) – A standard optimizer such as Adam or SGD.
init_scale (float, optional, default=2.**16) – Initial scale factor.
growth_factor (float, optional, default=2.0) – Factor by which the scale is multiplied during update() if no inf/NaN gradients occur for growth_interval consecutive iterations.
backoff_factor (float, optional, default=0.5) – Factor by which the scale is multiplied during update() if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000) – Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by growth_factor.
enabled (bool, optional, default=True) – If False, disables gradient scaling: step() simply invokes the underlying optimizer.step(), and other methods become no-ops.
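The interaction of init_scale, growth_factor, backoff_factor, growth_interval and enabled can be sketched in pure Python. DynamicLossScalerSketch is a hypothetical class that mirrors only the scale bookkeeping described by these parameters: checking for inf/NaN gradients, skipping the step, backing off, and growing after growth_interval clean iterations. It does not scale losses, unscale tensors, or wrap a real optimizer, and its step()/update() split is an assumption modeled on the calls the parameter descriptions mention.

```python
import math
from typing import List


class DynamicLossScalerSketch:
    """Hypothetical pure-Python model of the scale-update rule described above.

    Only the scale bookkeeping is modeled: no tensors are scaled or unscaled
    and no real optimizer is wrapped.
    """

    def __init__(self, init_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000,
                 enabled: bool = True) -> None:
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.enabled = enabled
        self._clean_steps = 0    # consecutive iterations without inf/NaN gradients
        self._found_inf = False  # result of the most recent step() check

    def step(self, scaled_grads: List[float]) -> bool:
        """Return True if the optimizer step would be applied this iteration."""
        if not self.enabled:
            return True  # scaling disabled: always defer to optimizer.step()
        self._found_inf = any(not math.isfinite(g) for g in scaled_grads)
        return not self._found_inf  # skip the step when any gradient is inf/NaN

    def update(self) -> None:
        """Adjust the scale based on the most recent step() check."""
        if not self.enabled:
            return  # no-op when scaling is disabled
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._clean_steps = 0


scaler = DynamicLossScalerSketch(init_scale=1024.0, growth_interval=3)
for grads in ([0.1], [0.2], [0.3]):  # three consecutive clean iterations
    scaler.step(grads)
    scaler.update()
print(scaler.scale)  # 2048.0 (grown by growth_factor)
print(scaler.step([float("inf")]))  # False: this step is skipped
scaler.update()
print(scaler.scale)  # 1024.0 (reduced by backoff_factor)
```

Keeping step() and update() separate follows the parameter descriptions, which say the scale changes during update(); in this sketch step() only decides whether the optimizer step would run, so a single overflow costs one skipped iteration and one halving of the scale.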