colossalai.amp.naive_amp

colossalai.amp.naive_amp.convert_to_naive_amp(model, optimizer, amp_config)

A helper function that wraps training components (the model and the optimizer) with naive AMP modules.

Parameters
  • model (torch.nn.Module) – your model object

  • optimizer (torch.optim.Optimizer) – your optimizer object

  • amp_config (colossalai.context.Config or dict) – configuration for naive AMP mode

Returns

(model, optimizer), the wrapped model and optimizer

Return type

Tuple
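Naive AMP keeps parameters in fp16, so the wrapped components inherit fp16's limited precision and range. The stdlib-only sketch below does not use ColossalAI. It round-trips Python floats through IEEE 754 half precision with `struct`'s `'e'` format, so the effect of the cast on individual values is visible. The helper name `to_fp16` is hypothetical.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Finite values beyond the fp16 range make struct raise OverflowError;
    they are mapped to +/-inf here, which is what an fp16 tensor would hold.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


# Precision: fp16 has an 11-bit significand, so a 1e-4 update to 1.0 is lost.
print(to_fp16(1.0 + 1e-4))   # 1.0
# Underflow: values below 2**-25 (about 3e-8) round to zero.
print(to_fp16(1e-8))         # 0.0
# Overflow: the largest finite fp16 value is 65504.
print(to_fp16(65504.0))      # 65504.0
print(to_fp16(70000.0))      # inf
```

These three failure modes (lost small updates, vanishing small gradients, overflowing large values) are what an FP16 optimizer wrapper has to work around.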

class colossalai.amp.naive_amp.NaiveAMPOptimizer(optim, *args, **kwargs)

A wrapper class for an optimizer that casts all parameters to fp16.

Parameters
  • optim (torch.optim.Optimizer) – a standard optimizer such as Adam or SGD

  • args – positional arguments used to initialize the FP16 optimizer

  • kwargs – keyword arguments used to initialize the FP16 optimizer
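Per the parameter list, the wrapper builds an FP16 optimizer from args and kwargs. FP16 optimizers commonly keep a higher-precision master copy of each parameter and use dynamic loss scaling, so that small fp16 gradients do not flush to zero. The sketch below assumes that design rather than reproducing ColossalAI's code. Everything in it (`ToyFP16SGD`, `init_scale`, `growth_factor`, `backoff_factor`, `growth_interval`, `scale_grad`) is hypothetical, and Python floats (fp64) stand in for fp32 master weights.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a float to the nearest IEEE 754 half-precision value (overflow -> inf)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class ToyFP16SGD:
    """Hypothetical FP16 optimizer sketch: plain SGD on high-precision master
    weights, fp16 working copies, and dynamic loss scaling.
    Not ColossalAI's implementation."""

    def __init__(self, params, lr=0.1, init_scale=2.0 ** 16,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=1000):
        self.master = [float(p) for p in params]              # master weights (fp64 stands in for fp32)
        self.fp16_params = [to_fp16(p) for p in self.master]  # copies the model computes with
        self.lr = lr
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_grad(self, grad: float) -> float:
        """Simulate backprop of a scaled loss: the fp16 gradient holds grad * scale."""
        return to_fp16(grad * self.scale)

    def step(self, fp16_grads) -> bool:
        """Apply one update; return False if it was skipped because of overflow."""
        if any(math.isinf(g) or math.isnan(g) for g in fp16_grads):
            # Overflow: skip the update and shrink the loss scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        for i, g in enumerate(fp16_grads):
            self.master[i] -= self.lr * (g / self.scale)   # unscale, update in high precision
            self.fp16_params[i] = to_fp16(self.master[i])  # refresh the fp16 working copy
        self._good_steps += 1
        if self._good_steps % self.growth_interval == 0:
            # A run of overflow-free steps: try a larger scale.
            self.scale *= self.growth_factor
        return True


opt = ToyFP16SGD([1.0], lr=1.0)

# Unscaled, a 1e-8 gradient would flush to zero in fp16; scaled by 2**16 it survives.
ok = opt.step([opt.scale_grad(1e-8)])
print(ok, opt.master[0] < 1.0, opt.fp16_params[0])  # True True 1.0

# A gradient of 10.0 overflows at scale 2**16: the step is skipped and the scale halves.
ok = opt.step([opt.scale_grad(10.0)])
print(ok, opt.scale)  # False 32768.0
```

The master copy records the tiny update even though its fp16 copy still rounds to 1.0. Skipping overflowing steps and backing off the scale is what lets training start with an aggressively large scale.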