colossalai.amp.torch_amp.torch_amp
- class colossalai.amp.torch_amp.torch_amp.TorchAMPOptimizer(optim, *args, **kwargs)
A wrapper class that integrates PyTorch AMP (automatic mixed precision) with an optimizer
- Parameters
optim (torch.optim.Optimizer) – A standard optimizer such as Adam or SGD
args – Positional arguments used to initialize the gradient scaler
kwargs – Keyword arguments used to initialize the gradient scaler
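Because args and kwargs are passed through to the gradient scaler, they are the scaler's own options. Assuming the scaler follows torch.cuda.amp.GradScaler, those options are init_scale, growth_factor, backoff_factor, growth_interval and enabled. The pure-Python sketch below mimics only the forwarding pattern; ToyGradScaler and ToyAMPOptimizer are hypothetical stand-ins, not ColossalAI classes, and the defaults shown are GradScaler's.

```python
# Hypothetical stand-in for a gradient scaler. The keyword names and defaults
# mirror torch.cuda.amp.GradScaler, but this class only stores the settings.
class ToyGradScaler:
    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000, enabled=True):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.enabled = enabled


# Hypothetical wrapper: keeps the user's optimizer and forwards every extra
# positional/keyword argument straight to the scaler constructor.
class ToyAMPOptimizer:
    def __init__(self, optim, *args, **kwargs):
        self.optim = optim
        self.scaler = ToyGradScaler(*args, **kwargs)


wrapped = ToyAMPOptimizer("sgd-placeholder", init_scale=1024.0, growth_interval=100)
print(wrapped.scaler.scale, wrapped.scaler.growth_interval)  # 1024.0 100
```

Forwarding everything keeps the wrapper's signature stable even if the scaler gains new options.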
- backward(loss)
Run the backward pass using the torch AMP gradient scaler
- Parameters
loss (torch.Tensor) – Loss computed by a loss function
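Loss scaling is what makes the backward pass safe in fp16. By the chain rule, multiplying the loss by a factor S multiplies every gradient by S. This lifts small gradients out of fp16's underflow region, and they are divided by S again before the update. In the standard PyTorch AMP recipe this corresponds to scaler.scale(loss).backward(). The sketch below uses Python's struct "e" (half-precision) format to show the effect on a single scalar gradient. It illustrates the idea only and is not the wrapper's code path; to_fp16 is a hypothetical helper.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # Values beyond the fp16 range (about 65504) become infinite in fp16 math.
        return float("inf") if x > 0 else float("-inf")


true_grad = 1e-8       # below fp16's smallest subnormal (about 6e-8)
scale = 2.0 ** 16      # loss scale S: d(S * loss)/dw == S * dloss/dw

unscaled_fp16 = to_fp16(true_grad)          # flushes to 0.0, so the update is lost
scaled_fp16 = to_fp16(true_grad * scale)    # about 6.55e-4, representable in fp16
recovered = scaled_fp16 / scale             # unscale in full precision

print(unscaled_fp16)  # 0.0
print(recovered)      # close to 1e-8
```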
- step()
Update the parameters of the model
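With a gradient scaler, an optimizer step is conditional. The gradients are unscaled, the update is skipped if any of them is inf or NaN, and the scale is then adjusted: it shrinks after an overflow and grows after a run of clean steps. In PyTorch this work is split between scaler.step(optimizer) and scaler.update(). The toy sketch below reproduces that logic for plain SGD on lists of floats. amp_step and the scaler_state dict layout are hypothetical and are not ColossalAI's API.

```python
import math


def amp_step(params, grads, scaler_state, lr=0.1):
    """Toy SGD step with dynamic loss scaling.

    params, grads: lists of floats (grads are still multiplied by the scale).
    scaler_state: dict with scale, growth_factor, backoff_factor,
    growth_interval and growth_tracker (a hypothetical layout).
    Returns True if the update was applied, False if it was skipped.
    """
    scale = scaler_state["scale"]
    unscaled = [g / scale for g in grads]
    found_inf = any(not math.isfinite(g) for g in unscaled)

    if found_inf:
        # Overflow: skip the update and back off the scale.
        scaler_state["scale"] = scale * scaler_state["backoff_factor"]
        scaler_state["growth_tracker"] = 0
        return False

    for i, g in enumerate(unscaled):
        params[i] -= lr * g

    # After growth_interval consecutive clean steps, try a larger scale.
    scaler_state["growth_tracker"] += 1
    if scaler_state["growth_tracker"] == scaler_state["growth_interval"]:
        scaler_state["scale"] = scale * scaler_state["growth_factor"]
        scaler_state["growth_tracker"] = 0
    return True


state = {"scale": 8.0, "growth_factor": 2.0, "backoff_factor": 0.5,
         "growth_interval": 2, "growth_tracker": 0}
params = [1.0, -1.0]

print(amp_step(params, [8.0, float("inf")], state), state["scale"])  # False 4.0
print(amp_step(params, [4.0, -4.0], state), params)                  # True [0.9, -0.9]
print(amp_step(params, [4.0, -4.0], state), state["scale"])          # True 8.0
```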
- clip_grad_norm(model, max_norm)
Apply gradient clipping to the model parameters
- Parameters
model (torch.nn.Module) – The model whose parameter gradients are clipped
max_norm (float) – Max norm value for gradient clipping
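Gradients produced by a scaled backward pass are still multiplied by the loss scale. They must therefore be unscaled before clipping, or max_norm would be compared against an inflated norm. The PyTorch AMP recipe calls scaler.unscale_(optimizer) before torch.nn.utils.clip_grad_norm_ for this reason. The pure-Python sketch below shows that ordering and the clip-by-global-norm rule on a flat list of gradients. It is an illustration under these assumptions, not ColossalAI's implementation.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale grads in place so their global L2 norm is at most max_norm.

    Follows the usual clip-by-global-norm rule and returns the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for i in range(len(grads)):
            grads[i] *= clip_coef
    return total_norm


scale = 1024.0
scaled_grads = [3.0 * scale, 4.0 * scale]  # gradients still carry the loss scale

# Unscale first; otherwise the norm would be inflated by a factor of `scale`.
grads = [g / scale for g in scaled_grads]
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, grads)  # 5.0, then roughly [0.6, 0.8]
```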
- class colossalai.amp.torch_amp.torch_amp.TorchAMPModel(model)
A wrapper class for a model object that runs the forward pass with values automatically cast to fp16
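The idea of a forward-pass wrapper can be sketched in pure Python by rounding float inputs to half precision before calling the wrapped callable. This is a deliberate simplification. PyTorch's autocast chooses precision per operation: on CUDA, for example, matrix multiplications run in fp16 while ops such as softmax run in fp32. It does not cast every input up front. ToyFP16Model, to_fp16 and linear are hypothetical names.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest half-precision value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


class ToyFP16Model:
    """Hypothetical wrapper: rounds float inputs to fp16, then calls the model."""

    def __init__(self, model):
        self.model = model

    def __call__(self, *inputs):
        cast = [to_fp16(x) if isinstance(x, float) else x for x in inputs]
        return self.model(*cast)


def linear(x, w=0.5, b=0.1):
    """A stand-in 'model': y = w * x + b."""
    return w * x + b


wrapped = ToyFP16Model(linear)
out = wrapped(0.1)   # 0.1 becomes 0.0999755859375 in fp16 before the multiply
print(out)
```

The small gap between out and linear(0.1) is the fp16 rounding error that mixed precision trades for speed and memory.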
- class colossalai.amp.torch_amp.torch_amp.TorchAMPLoss(loss)
A wrapper class for a criterion object that computes the loss in a mixed-precision context
- Parameters
loss (torch.nn.modules.loss._Loss) – A loss function object
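A loss usually ends in a reduction over many elements, so the precision of the accumulation matters. This is why PyTorch's autocast, for instance, runs many loss functions (mse_loss, nll_loss and others) in float32. The pure-Python sketch below compares mean squared error accumulated with every partial sum rounded to fp16 against the same reduction in full precision. mse and to_fp16 are hypothetical helpers, not part of the documented API.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest half-precision value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def mse(preds, targets, fp16_accumulate=False):
    """Mean squared error; optionally round every partial sum to fp16."""
    total = 0.0
    for p, t in zip(preds, targets):
        total += (p - t) ** 2
        if fp16_accumulate:
            total = to_fp16(total)
    return total / len(preds)


preds = [1.0] * 4096    # every prediction is off by exactly 1.0
targets = [0.0] * 4096

print(mse(preds, targets))                        # 1.0
print(mse(preds, targets, fp16_accumulate=True))  # 0.5: the fp16 sum stalls at 2048
```

Above 2048, adjacent fp16 values are 2 apart, so 2048 + 1 rounds back to 2048 (ties to even) and the running sum stops growing.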