colossalai.amp

colossalai.amp.convert_to_amp(model, optimizer, criterion, mode, amp_config=None)

A helper function that wraps the model, optimizer, and loss function with the automatic mixed precision (AMP) implementation selected by mode.

Parameters
  • model (torch.nn.Module) – the model to wrap

  • optimizer (torch.optim.Optimizer) – the optimizer to wrap

  • criterion (torch.nn.modules.loss._Loss) – the loss function to wrap

  • mode (colossalai.amp.AMP_TYPE) – the AMP mode to use

  • amp_config (colossalai.context.Config or dict) – configuration for the selected AMP mode; the accepted keys depend on mode (defaults to None)

Returns

A tuple (model, optimizer, criterion) containing the wrapped components.

Return type

Tuple
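The contract above (validate mode, fall back to an empty config when amp_config is None, return three objects that replace the originals) can be illustrated with a minimal, self-contained sketch. This is not ColossalAI's implementation: AmpMode is a hypothetical stand-in for colossalai.amp.AMP_TYPE, AmpWrapped is a hypothetical wrapper, strings stand in for the real model, optimizer, and loss objects, and the init_scale key (borrowed from PyTorch's GradScaler) is an assumed example value. The sketch wraps all three components the same way, whereas the real per-mode wrappers may differ.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Mapping, Optional, Tuple


class AmpMode(Enum):
    # Hypothetical stand-in for colossalai.amp.AMP_TYPE.
    TORCH = "torch"
    APEX = "apex"
    NAIVE = "naive"


@dataclass
class AmpWrapped:
    # Hypothetical wrapper that records which mode and config wrapped `inner`.
    inner: Any
    mode: AmpMode
    config: Dict[str, Any] = field(default_factory=dict)


def convert_to_amp_sketch(
    model: Any,
    optimizer: Any,
    criterion: Any,
    mode: AmpMode,
    amp_config: Optional[Mapping[str, Any]] = None,
) -> Tuple[Any, Any, Any]:
    """Illustrative only: validate `mode`, default `amp_config`, wrap, return a 3-tuple."""
    if not isinstance(mode, AmpMode):
        raise TypeError(f"mode must be an AmpMode, got {type(mode).__name__}")
    # Treat a missing config as empty; copy it so the caller's dict is not mutated.
    config = dict(amp_config) if amp_config is not None else {}
    # Every component is wrapped identically here; real per-mode wrappers are not modeled.
    return (
        AmpWrapped(model, mode, config),
        AmpWrapped(optimizer, mode, config),
        AmpWrapped(criterion, mode, config),
    )


# Usage: rebind all three names, because the returned objects replace the originals.
model, optimizer, criterion = "model", "optimizer", "criterion"
model, optimizer, criterion = convert_to_amp_sketch(
    model, optimizer, criterion, AmpMode.TORCH, {"init_scale": 2.0**16}
)
print(model.mode, model.config)  # AmpMode.TORCH {'init_scale': 65536.0}
```

The key usage point carries over to the real function: keep training with the returned model, optimizer, and criterion, not the objects you passed in, since only the returned ones carry the AMP behavior.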