colossalai.amp.naive_amp

colossalai.amp.naive_amp.convert_to_naive_amp(model, optimizer, amp_config)

A helper function that wraps training components with naive AMP modules. In this mode, the model weights and inputs are forcibly cast to FP16, and the model outputs are cast back to FP32 for the loss computation. This is equivalent to Apex's O3 optimization level.
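To see why the outputs are cast back to FP32 before the loss is computed, the sketch below emulates an FP16 cast in plain Python using the standard library's half-precision `struct` format (`'e'`). It only illustrates FP16 rounding and range limits; it is not ColossalAI code, and the helper name `to_fp16` is hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision (FP16) value.

    Hypothetical helper for illustration: struct's 'e' format packs binary16
    and raises OverflowError for finite values beyond the FP16 range.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


# FP16 has only 10 explicit mantissa bits, so small increments vanish:
print(to_fp16(1.0 + 1e-4))  # 1.0
# Values far below the smallest FP16 subnormal (about 6e-8) underflow to zero:
print(to_fp16(1e-8))  # 0.0
# Values above the FP16 maximum (65504) cannot be represented at all:
try:
    to_fp16(70000.0)
except OverflowError:
    print("overflow: value exceeds the FP16 range")
```

Losses and their small differences are exactly the kind of values that suffer from this rounding, which is why computing the loss in FP32 is preferred.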

Parameters
  • model (torch.nn.Module) – your model object

  • optimizer (torch.optim.Optimizer) – your optimizer object

  • amp_config (colossalai.context.Config or dict) – configuration for naive AMP mode.

The amp_config can contain the following parameters:

  • verbose (bool, optional) – if set to True, will print debug info. Default False.

  • clip_grad_norm (float, optional) – clip gradients with this global L2 norm. Default 0. Note that clipping is ignored if clip_grad_norm == 0.

  • dynamic_grad_scale (bool) – whether to use a dynamic grad scaler.
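As an illustration, an amp_config for this mode can be written as a plain dict. The sketch below builds one and interprets it with a hypothetical helper, `summarize_amp_config` (not part of ColossalAI), that applies only the documented semantics: verbose defaults to False, clip_grad_norm defaults to 0, and a value of 0 disables clipping. No default is documented for dynamic_grad_scale, so the helper reports it as unset when it is missing rather than guessing.

```python
from typing import Any, Dict

# Example amp_config for naive AMP mode (values chosen for illustration).
amp_config: Dict[str, Any] = {
    "verbose": False,
    "clip_grad_norm": 1.0,       # global L2 norm threshold; 0 disables clipping
    "dynamic_grad_scale": True,  # use a dynamic rather than a constant grad scaler
}


def summarize_amp_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: interpret an amp_config using the documented defaults."""
    clip = float(cfg.get("clip_grad_norm", 0))  # documented default: 0
    return {
        "verbose": bool(cfg.get("verbose", False)),  # documented default: False
        "clipping_enabled": clip != 0,  # clipping is ignored when the norm is 0
        "clip_grad_norm": clip,
        # No default is documented for dynamic_grad_scale; None means "not set".
        "dynamic_grad_scale": cfg.get("dynamic_grad_scale"),
    }


print(summarize_amp_config(amp_config))
print(summarize_amp_config({"clip_grad_norm": 0}))
```

With ColossalAI installed, such a dict would then be passed as `model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)`.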
Returns

A tuple (model, optimizer) containing the wrapped model and the wrapped optimizer.

Return type

tuple

class colossalai.amp.naive_amp.NaiveAMPOptimizer(optim, *args, **kwargs)

A wrapper class for an optimizer that casts all parameters to fp16.

Parameters
  • optim (torch.optim.Optimizer) – A normal optimizer like Adam or SGD.

  • grad_scaler (BaseGradScaler) – gradient scaler, chosen from constant_grad_scaler or dynamic_grad_scaler.

  • clip_grad_norm (float, optional) – clip gradients with this global L2 norm. Default 0.

  • verbose (bool, optional) – if set to True, will print debug info. Default False.

Note

Clipping is ignored if clip_grad_norm equals 0.
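The clipping rule can be sketched in plain Python, assuming each parameter's gradient is represented as a list of floats. This is a framework-free illustration of clipping by global L2 norm, not the NaiveAMPOptimizer implementation; the function name `clip_by_global_norm` is hypothetical. In a loss-scaled setup, gradients would typically be unscaled before this norm is measured.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], clip_grad_norm: float) -> float:
    """Scale all gradients in place so their global L2 norm is at most clip_grad_norm.

    Hypothetical helper. Returns the global norm measured before clipping.
    Following the documented convention, clip_grad_norm == 0 disables clipping.
    """
    # Global norm: square root of the sum of squares over every gradient element
    # of every parameter, not a separate norm per parameter.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if clip_grad_norm > 0 and total_norm > clip_grad_norm:
        coef = clip_grad_norm / total_norm
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters"; their global norm is 5.0
norm = clip_by_global_norm(grads, clip_grad_norm=1.0)
print(norm, grads)  # 5.0, and gradients of roughly [[0.6], [0.8]]
```

Because every gradient is multiplied by the same coefficient, the update direction is preserved and only its magnitude is capped.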

class colossalai.amp.naive_amp.FP16Optimizer(optimizer, grad_scaler, verbose=False, clip_grad_norm=0, dp_process_group=None, mp_process_group=None)

Float16 optimizer for fp16 and bf16 data types.

Parameters
  • optimizer (torch.optim.Optimizer) – base optimizer such as Adam or SGD

  • grad_scaler (BaseGradScaler) – gradient scaler, chosen from constant_grad_scaler or dynamic_grad_scaler.

  • clip_grad_norm (float, optional) – clip gradients with this global L2 norm. Default 0. Note that clipping is ignored if clip_grad_norm == 0.

  • verbose (bool, optional) – if set to True, will print debug info. Default False.
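The grad_scaler argument refers to loss (gradient) scaling, which keeps small FP16 gradients from underflowing. The following is a simplified, framework-free sketch of dynamic loss scaling combined with an optimizer step on plain Python lists: a step whose gradients contain inf or NaN is skipped and the scale is backed off, and the scale grows after a run of good steps. The class `SimpleDynamicScaler`, the function `optimizer_step`, and all default values are hypothetical; they are not ColossalAI's DynamicGradScaler or FP16Optimizer API.

```python
import math
from typing import List


class SimpleDynamicScaler:
    """Hypothetical, simplified dynamic loss scaler (not ColossalAI's API).

    The loss is assumed to have been multiplied by `scale` before backward,
    so the gradients passed to optimizer_step are scaled by the same factor.
    """

    def __init__(self, init_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 1000,
                 min_scale: float = 1.0) -> None:
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.min_scale = min_scale
        self._good_steps = 0

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            # Overflow: shrink the scale (not below min_scale) and restart the count.
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                # A run of finite-gradient steps: try a larger scale.
                self.scale *= self.growth_factor
                self._good_steps = 0


def optimizer_step(params: List[float], scaled_grads: List[float],
                   scaler: SimpleDynamicScaler, lr: float) -> bool:
    """Apply plain SGD on unscaled gradients; return False if the step was skipped."""
    found_overflow = any(not math.isfinite(g) for g in scaled_grads)
    if not found_overflow:
        inv_scale = 1.0 / scaler.scale
        for i, g in enumerate(scaled_grads):
            params[i] -= lr * g * inv_scale
    scaler.update(found_overflow)
    return not found_overflow


scaler = SimpleDynamicScaler(init_scale=1024.0, growth_interval=2)
params = [1.0, 2.0]
print(optimizer_step(params, [float("inf"), 0.0], scaler, lr=0.1))  # False
print(scaler.scale)  # 512.0
print(optimizer_step(params, [512.0, 1024.0], scaler, lr=0.1))  # True
```

A constant scaler would keep `scale` fixed; the dynamic variant trades occasional skipped steps for a scale that adapts to the gradient magnitudes.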