colossalai.amp.naive_amp.naive_amp

class colossalai.amp.naive_amp.naive_amp.NaiveAMPOptimizer(optim, *args, **kwargs)

A wrapper class for an optimizer that casts all parameters to fp16.

Parameters
  • optim (torch.optim.Optimizer) – A standard PyTorch optimizer such as Adam or SGD

  • args – Positional arguments used to initialize the FP16 optimizer

  • kwargs – Keyword arguments used to initialize the FP16 optimizer
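Mixed-precision optimizer wrappers commonly keep fp32 "master" copies of the fp16 parameters and use loss scaling, so that small updates and small gradients are not lost to fp16 rounding. This page does not describe the internals of the FP16 optimizer, so the following is only a toy, stdlib-only sketch of that general technique (fp32 master weights plus dynamic loss scaling) applied to plain SGD. It emulates fp16/fp32 rounding with the `struct` module's `'e'` and `'f'` formats. `ToyFP16Optimizer`, `lr`, `loss_scale` and `growth_interval` are hypothetical names, not ColossalAI's API.

```python
import math
import struct
from typing import List, Sequence


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (overflow -> +/-inf)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)


def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


class ToyFP16Optimizer:
    """Hypothetical SGD wrapper: fp16 model params, fp32 master params, dynamic loss scale."""

    def __init__(self, params: Sequence[float], lr: float = 0.1,
                 loss_scale: float = 2.0 ** 16, growth_interval: int = 1000):
        self.master: List[float] = [to_fp32(p) for p in params]           # fp32 master copy
        self.fp16_params: List[float] = [to_fp16(p) for p in self.master]  # what the model uses
        self.lr = lr
        self.loss_scale = loss_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, scaled_grads: Sequence[float]) -> bool:
        """Apply gradients of (loss * loss_scale); return False if the step was skipped."""
        grads = [g / self.loss_scale for g in scaled_grads]  # unscale before updating
        if not all(math.isfinite(g) for g in grads):
            self.loss_scale /= 2.0  # overflow detected: shrink the scale and skip this update
            self._good_steps = 0
            return False
        for i, g in enumerate(grads):
            self.master[i] = to_fp32(self.master[i] - self.lr * g)  # update in fp32
            self.fp16_params[i] = to_fp16(self.master[i])           # refresh the fp16 copy
        self._good_steps += 1
        if self._good_steps % self.growth_interval == 0:
            self.loss_scale *= 2.0  # no overflow for a while: try a larger scale
        return True
```

The master copy is what makes tiny updates survive: applying a gradient of `1e-4` ten times with `lr=1.0` and `loss_scale=1.0` moves `fp16_params[0]` from 1.0 to just below 1.0, whereas updating the fp16 value directly would leave it at 1.0, because `to_fp16(1.0 - 1e-4)` rounds back to 1.0.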

class colossalai.amp.naive_amp.naive_amp.NaiveAMPModel(model, output_to_fp32=True, parallel_mode=ParallelMode.DATA, sync_buffer=True)

A wrapper class for a model that casts the model to fp16 and automatically casts its inputs and outputs.
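To make the casting behaviour concrete, here is a toy, stdlib-only sketch. It emulates fp16/fp32 rounding with the `struct` module's `'e'` and `'f'` formats and mimics only the weight, input and output casting and the `output_to_fp32` flag from the signature above; `parallel_mode` and `sync_buffer` are not modelled. `Tensor`, `cast` and `ToyAMPModel` are hypothetical names, not part of ColossalAI.

```python
import math
import struct
from dataclasses import dataclass
from typing import List


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (overflow -> +/-inf)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)


def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


@dataclass
class Tensor:
    values: List[float]
    dtype: str  # "fp16" or "fp32"; a toy stand-in for a real tensor dtype


def cast(t: Tensor, dtype: str) -> Tensor:
    """Round every element to the target precision and retag the dtype."""
    round_fn = to_fp16 if dtype == "fp16" else to_fp32
    return Tensor([round_fn(v) for v in t.values], dtype)


class ToyAMPModel:
    """Hypothetical wrapper: fp16 weights, fp16 inputs, optionally fp32 outputs."""

    def __init__(self, weight: List[float], output_to_fp32: bool = True):
        self.weight = cast(Tensor(weight, "fp32"), "fp16")  # parameters live in fp16
        self.output_to_fp32 = output_to_fp32

    def _module(self, x: Tensor) -> Tensor:
        # The wrapped "model": an element-wise scale, computed and rounded in fp16.
        return Tensor([to_fp16(w * v) for w, v in zip(self.weight.values, x.values)], "fp16")

    def __call__(self, x: Tensor) -> Tensor:
        out = self._module(cast(x, "fp16"))  # inputs are cast to fp16 automatically
        return cast(out, "fp32") if self.output_to_fp32 else out
```

For example, `ToyAMPModel([0.1, 2.0])(Tensor([3.0, 1e5], "fp32"))` returns a Tensor tagged `"fp32"` whose first element is close to 0.3, while the 1e5 input shows the narrow fp16 range: it exceeds the fp16 maximum of 65504 and becomes `inf` as soon as it is cast.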