colossalai.nn.loss.loss_moe

class colossalai.nn.loss.loss_moe.MoeCrossEntropyLoss(aux_weight=0.01, *args, **kwargs)

A torch.nn.CrossEntropyLoss with a weighted auxiliary loss added to it.

Parameters
  • aux_weight (float, optional) – Weight of the auxiliary loss in the total loss. Defaults to 0.01.

  • args – Positional arguments for torch.nn.CrossEntropyLoss.

  • kwargs – Keyword arguments for torch.nn.CrossEntropyLoss.
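A minimal pure-Python sketch of the combined loss, assuming total loss = cross-entropy + aux_weight × auxiliary loss. The names cross_entropy and MoeCrossEntropyLossSketch are hypothetical and are not part of the ColossalAI API. The auxiliary loss is also passed in explicitly here. That is an assumption, because this page does not say where the real class gets it.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean cross-entropy over a batch (like reduction='mean' in torch.nn.CrossEntropyLoss)."""
    total = 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp: subtract the row maximum first.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
    return total / len(targets)


class MoeCrossEntropyLossSketch:
    """Hypothetical stand-in: cross-entropy plus a weighted auxiliary loss."""

    def __init__(self, aux_weight: float = 0.01):
        self.aux_weight = aux_weight

    def __call__(self, logits, targets, aux_loss: float) -> float:
        # Assumption: the auxiliary loss (e.g. a router balancing term) is supplied by the caller.
        main_loss = cross_entropy(logits, targets)
        return main_loss + self.aux_weight * aux_loss


loss_fn = MoeCrossEntropyLossSketch(aux_weight=0.01)
print(loss_fn([[0.0, 0.0]], [0], aux_loss=1.0))  # ≈ 0.7031 (log 2 + 0.01)
```

With the default aux_weight of 0.01, the auxiliary term only nudges the main cross-entropy objective and does not dominate it.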

class colossalai.nn.loss.loss_moe.MoeLoss(aux_weight, loss_fn, *args, **kwargs)

A wrapper class that adds a weighted auxiliary loss to any loss module.

Parameters
  • aux_weight (float) – Weight of the auxiliary loss in the total loss.

  • loss_fn (Callable) – Loss function to wrap.

  • args – Positional arguments for the loss function.

  • kwargs – Keyword arguments for the loss function.
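A minimal sketch of the same wrapping idea for an arbitrary loss. It assumes loss_fn is a loss class (or factory) built once from args and kwargs, with the auxiliary loss added as aux_weight × auxiliary loss. MeanSquaredError and MoeLossSketch are hypothetical names used only for illustration, and the auxiliary loss is passed in explicitly as an assumption.

```python
from typing import Callable, Sequence


class MeanSquaredError:
    """Hypothetical loss module used only for illustration."""

    def __init__(self, reduction: str = "mean"):
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction}")
        self.reduction = reduction

    def __call__(self, preds: Sequence[float], targets: Sequence[float]) -> float:
        total = sum((p - t) ** 2 for p, t in zip(preds, targets))
        return total / len(preds) if self.reduction == "mean" else total


class MoeLossSketch:
    """Hypothetical stand-in for MoeLoss: any loss plus a weighted auxiliary loss."""

    def __init__(self, aux_weight: float, loss_fn: Callable, *args, **kwargs):
        # Assumption: loss_fn is instantiated once with the forwarded *args and **kwargs.
        self.loss = loss_fn(*args, **kwargs)
        self.aux_weight = aux_weight

    def __call__(self, *inputs, aux_loss: float) -> float:
        # Assumption: the auxiliary loss is supplied by the caller.
        main_loss = self.loss(*inputs)
        return main_loss + self.aux_weight * aux_loss


wrapped = MoeLossSketch(0.1, MeanSquaredError, reduction="sum")
print(wrapped([1.0, 2.0], [0.0, 0.0], aux_loss=2.0))  # ≈ 5.2 (squared errors 1 + 4, plus 0.1 * 2.0)
```

Because the extra constructor arguments are forwarded to loss_fn, the wrapper does not need to know anything about the wrapped loss beyond how to build and call it.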