colossalai.amp
- colossalai.amp.convert_to_amp(model, optimizer, criterion, mode, amp_config=None)
A helper function that wraps the training components (model, optimizer, and loss function) with AMP (automatic mixed precision) modules for the selected mode.
- Parameters
model (torch.nn.Module) – your model object
optimizer (torch.optim.Optimizer) – your optimizer object
criterion (torch.nn.modules.loss._Loss) – your loss function object
mode (colossalai.amp.AMP_TYPE) – amp mode
amp_config (colossalai.context.Config or dict) – configuration for different amp modes
- Returns
(model, optimizer, criterion)
- Return type
Tuple
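
The following is a minimal usage sketch, not an official example. It assumes that `torch` and `colossalai` are installed and that `AMP_TYPE.TORCH` is one of the available modes. That member name is not listed in this section. The helper `build_amp_components` and the toy `Linear` model are hypothetical and exist only for illustration. Depending on the ColossalAI version, the distributed context may need to be initialized (for example via `colossalai.launch`) before this call succeeds.

```python
def build_amp_components(amp_config=None):
    """Wrap a toy model, optimizer, and loss function with AMP.

    Hypothetical helper for illustration only. Imports are kept local so
    that defining this function does not require torch or colossalai.
    """
    import torch
    from colossalai.amp import AMP_TYPE, convert_to_amp

    # Toy training components; replace these with your own objects.
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    # Assumption: AMP_TYPE.TORCH selects the Torch AMP mode.
    # amp_config may be None, a dict, or a colossalai.context.Config.
    model, optimizer, criterion = convert_to_amp(
        model, optimizer, criterion, mode=AMP_TYPE.TORCH, amp_config=amp_config
    )

    # The same three components come back, in the same order.
    return model, optimizer, criterion
```

The returned objects are meant to be used in place of the originals in the training loop, since the function returns the wrapped `(model, optimizer, criterion)` tuple.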