colossalai.nn.loss.loss_moe
- class colossalai.nn.loss.loss_moe.MoeCrossEntropyLoss(aux_weight=0.01, *args, **kwargs)
torch.nn.CrossEntropyLoss with an auxiliary loss added to it.
- Parameters
aux_weight (float, optional) – Weight of auxiliary loss in total loss
args – Positional arguments for torch.nn.CrossEntropyLoss
kwargs – Keyword arguments for torch.nn.CrossEntropyLoss
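How aux_weight enters the total can be sketched in plain Python. This is only an illustration, assuming the total is the cross-entropy term plus aux_weight times the auxiliary loss; the helper names cross_entropy and moe_cross_entropy are hypothetical and not part of ColossalAI, and the auxiliary loss is passed in explicitly here rather than collected from the MoE layers.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def moe_cross_entropy(logits, target, aux_loss, aux_weight=0.01):
    """Hypothetical helper: cross-entropy plus a weighted auxiliary term (assumed formula)."""
    return cross_entropy(logits, target) + aux_weight * aux_loss


# With the default weight, the auxiliary loss contributes 1% of its value to the total.
total = moe_cross_entropy([2.0, 0.5, -1.0], target=0, aux_loss=1.5)
```

A small aux_weight keeps the auxiliary term from dominating the classification objective while still letting it influence the gradients.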
- class colossalai.nn.loss.loss_moe.MoeLoss(aux_weight, loss_fn, *args, **kwargs)
A wrapper class that adds an auxiliary loss to any loss module.
- Parameters
aux_weight (float) – Weight of auxiliary loss in total loss
loss_fn (Callable) – Loss function
args – Positional arguments for the loss function
kwargs – Keyword arguments for the loss function
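The wrapper pattern can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not ColossalAI's implementation: AuxLossWrapper and MeanSquaredError are hypothetical names, loss_fn is assumed to be a class or factory that is constructed with args and kwargs, and the auxiliary loss is passed in explicitly rather than gathered from the model.

```python
class MeanSquaredError:
    """Hypothetical loss used only to demonstrate the wrapper."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, preds, targets):
        squared = [(p - t) ** 2 for p, t in zip(preds, targets)]
        return self.scale * sum(squared) / len(squared)


class AuxLossWrapper:
    """Hypothetical stand-in for the wrapper pattern described above."""

    def __init__(self, aux_weight, loss_fn, *args, **kwargs):
        self.aux_weight = aux_weight
        # Assumption: args and kwargs configure the wrapped loss at construction time.
        self.loss = loss_fn(*args, **kwargs)

    def __call__(self, aux_loss, *inputs):
        # Assumed formula: wrapped loss plus the weighted auxiliary loss.
        return self.loss(*inputs) + self.aux_weight * aux_loss


criterion = AuxLossWrapper(0.1, MeanSquaredError, scale=1.0)
total = criterion(2.0, [1.0, 2.0], [1.0, 4.0])  # MSE of 2.0 plus 0.1 * 2.0
```

Because the wrapper only relies on the wrapped object being callable, the same pattern applies to any loss that returns a value supporting addition.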