colossalai.nn.loss.loss_moe
- class colossalai.nn.loss.loss_moe.MoeCrossEntropyLoss(aux_weight=0.01, *args, **kwargs)
torch.nn.CrossEntropyLoss with an auxiliary loss added.
- Parameters
input (torch.Tensor) – Predicted unnormalized scores (often referred to as logits).
target (torch.Tensor) – Ground truth class indices or class probabilities.
aux_weight (float, optional) – Weight of the auxiliary loss in the total loss. Defaults to 0.01.
The args and kwargs may include the following parameters:
weight (Tensor, optional)
size_average (bool, optional)
ignore_index (int, optional)
reduce (bool, optional)
reduction (str, optional)
label_smoothing (float, optional)
More details about args, kwargs and torch.nn.functional.cross_entropy can be found in Cross_entropy.
- forward(*args)
The args should include at least the following parameters:
input (torch.Tensor) – Predicted unnormalized scores (often referred to as logits).
target (torch.Tensor) – Ground truth class indices or class probabilities.
More details about args, kwargs and torch.nn.functional.cross_entropy can be found in Cross_entropy.
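The role of aux_weight can be illustrated with a small sketch. It assumes that the total loss is the cross-entropy plus aux_weight times the auxiliary loss; the page only says that aux_weight is the "weight of auxiliary loss in total loss", so the exact formula is an assumption. The functions `cross_entropy` and `moe_total_loss` are hypothetical helpers written in plain Python (not torch or ColossalAI calls) so that the example runs without dependencies.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of a single sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def moe_total_loss(logits, target, aux_loss, aux_weight=0.01):
    """Assumed combination: main loss + aux_weight * auxiliary loss."""
    main_loss = cross_entropy(logits, target)
    return main_loss + aux_weight * aux_loss


logits = [2.0, 0.5, -1.0]   # unnormalized scores for three classes
target = 0                  # ground-truth class index
aux_loss = 1.5              # illustrative value; in MoE training this comes from the router

main = cross_entropy(logits, target)
total = moe_total_loss(logits, target, aux_loss)
print(f"main={main:.4f} total={total:.4f}")
```

With the default aux_weight of 0.01 the auxiliary term shifts the total only slightly, which is why it acts as a regularizer rather than as the main training signal.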
- class colossalai.nn.loss.loss_moe.MoeLoss(aux_weight, loss_fn, *args, **kwargs)
A wrapper class that adds an auxiliary loss to any loss module.
- Parameters
aux_weight (float) – Weight of the auxiliary loss in the total loss.
loss_fn (Callable) – Loss function.
args (list) – Positional arguments for the loss function.
kwargs (dict) – Keyword arguments for the loss function.
- forward(*args, **kwargs)
The args and kwargs should include at least the following parameters:
input (torch.Tensor) – Predicted unnormalized scores (often referred to as logits).
target (torch.Tensor) – Ground truth class indices or class probabilities.
Note
The args and kwargs may include different parameters depending on the loss function.
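The wrapper pattern described above can be sketched as follows. This is not the ColossalAI implementation: `AuxLossWrapper`, `MeanSquaredError` and the keyword-only `get_aux_loss` argument are hypothetical names introduced for illustration. The sketch assumes that loss_fn is a class or factory that receives args and kwargs at construction time, and that the auxiliary loss is fetched from a caller-supplied function at each call.

```python
class MeanSquaredError:
    """Toy loss module used only to demonstrate the wrapper."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, input, target):
        squared = [(x - y) ** 2 for x, y in zip(input, target)]
        return self.scale * sum(squared) / len(squared)


class AuxLossWrapper:
    """Hypothetical stand-in for a MoeLoss-style wrapper."""

    def __init__(self, aux_weight, loss_fn, *args, get_aux_loss, **kwargs):
        self.aux_weight = aux_weight
        # Assumption: loss_fn is a factory, built with the remaining arguments.
        self.loss_fn = loss_fn(*args, **kwargs)
        # Assumption: the auxiliary loss is supplied by a zero-argument callable.
        self.get_aux_loss = get_aux_loss

    def __call__(self, *args, **kwargs):
        # Forward all arguments to the wrapped loss, then add the weighted term.
        main_loss = self.loss_fn(*args, **kwargs)
        return main_loss + self.aux_weight * self.get_aux_loss()


criterion = AuxLossWrapper(0.1, MeanSquaredError, get_aux_loss=lambda: 2.0, scale=0.5)
loss = criterion([1.0, 2.0], [0.0, 2.0])
print(loss)
```

Because the wrapper forwards whatever it receives, the arguments accepted at call time are exactly those of the wrapped loss, which matches the note that they vary with the loss function.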