colossalai.amp.torch_amp
- colossalai.amp.torch_amp.convert_to_torch_amp(model, optimizer, criterion=None, amp_config=None)
A helper function that wraps training components with Torch AMP modules.
- Parameters
model (torch.nn.Module) – Your model object
optimizer (torch.optim.Optimizer) – Your optimizer object
criterion (torch.nn.modules.loss._Loss, optional) – Your loss function object
amp_config (colossalai.context.Config or dict, optional) – Configuration for different AMP modes
- Returns
(model, optimizer, criterion)
- Return type
Tuple
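In practice the helper is called as `model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)`. The following is a minimal, framework-free sketch of that wrapping flow. It uses hypothetical stub classes (`AMPModelStub`, `AMPOptimizerStub`, `AMPLossStub`) and a hypothetical `convert_sketch` function rather than the real ColossalAI wrappers. It also assumes, without confirmation from this page, that `amp_config` entries are forwarded to the optimizer wrapper as gradient-scaler keyword arguments and that an omitted criterion is returned as `None`.

```python
# Hypothetical stand-ins: these are NOT ColossalAI classes. They only mimic
# the wrapping pattern described for convert_to_torch_amp.
class AMPModelStub:
    def __init__(self, model):
        self.model = model


class AMPOptimizerStub:
    def __init__(self, optim, **scaler_kwargs):
        self.optim = optim
        # Assumption: amp_config entries become gradient-scaler kwargs.
        self.scaler_kwargs = scaler_kwargs


class AMPLossStub:
    def __init__(self, loss):
        self.loss = loss


def convert_sketch(model, optimizer, criterion=None, amp_config=None):
    """Wrap each component; leave criterion as None if none was given."""
    amp_config = dict(amp_config or {})  # accepts a dict or dict-like config
    model = AMPModelStub(model)
    optimizer = AMPOptimizerStub(optimizer, **amp_config)
    if criterion is not None:
        criterion = AMPLossStub(criterion)
    return model, optimizer, criterion


m, o, c = convert_sketch("model", "optimizer", amp_config={"init_scale": 2.0**16})
print(type(m).__name__, o.scaler_kwargs, c)
```

Returning a tuple in a fixed order lets callers rebind all three names in one assignment, whether or not a criterion was supplied.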
- class colossalai.amp.torch_amp.TorchAMPModel(model)
A wrapper class for a model object that executes the forward pass with values automatically cast to fp16.
- class colossalai.amp.torch_amp.TorchAMPLoss(loss)
A wrapper class for a criterion object that computes the loss in a mixed-precision context.
- Parameters
loss (torch.nn.modules.loss._Loss) – A loss function object
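One reason to compute the loss under a mixed-precision context rather than purely in fp16 is that reductions lose precision quickly in half precision. PyTorch's autocast documentation lists many loss ops (for example `mse_loss` and `nll_loss`) among those that run in float32. The sketch below uses only the standard library to emulate fp16 accumulation through `struct`'s `"e"` (IEEE half-precision) format. The helpers `to_fp16` and `fp16_sum` are illustrative, not part of any library.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fp16_sum(values):
    """Accumulate values, rounding the running total to fp16 after every add."""
    total = 0.0
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return total


per_sample_losses = [1.0] * 3000
exact_total = sum(per_sample_losses)
fp16_total = fp16_sum(per_sample_losses)
print(exact_total, fp16_total)
```

fp16 has an 11-bit significand, so above 2048 the spacing between representable values is 2. Adding 1.0 then rounds back to 2048 (round-half-to-even), and the fp16 running total stalls there.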
- class colossalai.amp.torch_amp.TorchAMPOptimizer(optim, *args, **kwargs)
A wrapper class that integrates PyTorch AMP with an optimizer.
- Parameters
optim (torch.optim.Optimizer) – A standard optimizer such as Adam or SGD
args – Positional arguments used to initialize the gradient scaler
kwargs – Keyword arguments used to initialize the gradient scaler
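Assuming the gradient scaler here is PyTorch's `torch.cuda.amp.GradScaler`, which this page does not state explicitly, `args` and `kwargs` would control its dynamic loss-scaling schedule. GradScaler's documented defaults are `init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5` and `growth_interval=2000`. The hypothetical `LossScalerSketch` below reproduces only the scale-update rule in pure Python, as an illustration rather than the real implementation.

```python
class LossScalerSketch:
    """Simplified dynamic loss scaling, modeled on GradScaler's update rule."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow detected: shrink the scale and restart the counter.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                # Enough clean steps in a row: try a larger scale.
                self.scale *= self.growth_factor
                self._good_steps = 0


scaler = LossScalerSketch(growth_interval=3)
for found_inf in [False, False, False, True, False]:
    scaler.update(found_inf)
print(scaler.scale)
```

Growing slowly and backing off immediately keeps the scale as large as possible without repeatedly overflowing fp16.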
- backward(loss)
Runs the backward pass using the Torch AMP gradient scaler.
- Parameters
loss (torch.Tensor) – Loss computed by a loss function
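Scaling the loss before the backward pass multiplies every gradient by the same factor (chain rule), which keeps small gradients from underflowing to zero in fp16. The pure-Python sketch below demonstrates the effect on a single value, emulating fp16 with `struct`'s `"e"` format. The value `1e-8` is an arbitrary small example, and `2**16` matches GradScaler's default initial scale.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


tiny_grad = 1e-8   # arbitrary small gradient value, below fp16's smallest subnormal
scale = 2.0**16    # GradScaler's default initial scale

unscaled = to_fp16(tiny_grad)          # underflows to 0.0 in fp16
scaled = to_fp16(tiny_grad * scale)    # survives as a nonzero fp16 value
recovered = scaled / scale             # unscaled again in higher precision
print(unscaled, scaled, recovered)
```

The smallest positive fp16 value is about 6e-8, so `1e-8` rounds to zero on its own. After scaling, it sits in fp16's normal range, where the relative rounding error is below 1e-3.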
- step()
Update the parameters of the model
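With PyTorch's `GradScaler`, a step unscales the gradients and skips the optimizer update when any gradient is inf or NaN. This page does not say whether `step()` delegates to such a scaler; the pure-Python sketch below illustrates that behavior under that assumption. `scaled_sgd_step` is a hypothetical helper that applies plain SGD to a flat list of parameters.

```python
import math


def scaled_sgd_step(params, scaled_grads, scale, lr=0.1):
    """Hypothetical helper: unscale gradients, then apply plain SGD only if
    every gradient is finite. Returns True if the update was applied."""
    grads = [g / scale for g in scaled_grads]
    if not all(math.isfinite(g) for g in grads):
        return False  # skip the step; params stay unchanged
    for i, g in enumerate(grads):
        params[i] -= lr * g
    return True


params = [1.0, 2.0]
applied = scaled_sgd_step(params, [65536.0, -131072.0], scale=65536.0)
print(applied, params)

frozen = [1.0, 2.0]
skipped = scaled_sgd_step(frozen, [float("inf"), 1.0], scale=65536.0)
print(skipped, frozen)
```

Skipping the whole update, rather than only the non-finite entries, avoids applying a partial step computed from an overflowed backward pass. The scale is then reduced for the next iteration.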
- clip_grad_norm(model, max_norm)
Apply gradient clipping to the model parameters
- Parameters
model (torch.nn.Module) – Your model object
max_norm (float) – Max norm value for gradient clipping
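Gradient clipping by norm rescales all gradients together so that their global L2 norm does not exceed `max_norm`. With PyTorch AMP, the AMP documentation recommends unscaling gradients (`GradScaler.unscale_`) before clipping, so the threshold applies to true gradient magnitudes. The pure-Python sketch below follows the rule used by `torch.nn.utils.clip_grad_norm_`; `clip_grad_norm_sketch` is an illustrative name and works on a flat list instead of tensors.

```python
import math


def clip_grad_norm_sketch(grads, max_norm, eps=1e-6):
    """Scale a flat list of gradients in place so their global L2 norm is at
    most max_norm. Returns the norm measured before clipping."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for i in range(len(grads)):
            grads[i] *= clip_coef
    return total_norm


grads = [3.0, 4.0]  # global norm 5.0
norm = clip_grad_norm_sketch(grads, max_norm=1.0)
print(norm, grads)
```

Clipping the global norm, rather than each value separately, preserves the direction of the overall gradient and only shrinks its length.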