colossalai.nn.loss.loss_moe

class colossalai.nn.loss.loss_moe.MoeCrossEntropyLoss(aux_weight=0.01, *args, **kwargs)[source]

torch.nn.CrossEntropyLoss with an auxiliary loss added.

Parameters
  • input (torch.tensor) – Predicted unnormalized scores (often referred to as logits).

  • target (torch.tensor) – Ground truth class indices or class probabilities.

  • aux_weight (float, optional) – Weight of the auxiliary loss in the total loss. Defaults to 0.01.

The args and kwargs may include the parameters below:

weight (Tensor, optional)
size_average (bool, optional)
ignore_index (int, optional)
reduce (bool, optional)
reduction (str, optional)
label_smoothing (float, optional)

More details about args, kwargs, and torch.nn.functional.cross_entropy can be found in the PyTorch cross_entropy documentation.

forward(*args)[source]

The args should include at least the parameters below:

input (torch.tensor) – Predicted unnormalized scores (often referred to as logits).
target (torch.tensor) – Ground truth class indices or class probabilities.

More details about args and torch.nn.functional.cross_entropy can be found in the PyTorch cross_entropy documentation.
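The role of aux_weight can be illustrated with a small framework-free calculation. This is only a sketch: it assumes the total loss is the cross entropy plus aux_weight times the auxiliary loss, and it passes the auxiliary loss in through a hypothetical aux_loss parameter, since this page does not describe how the library obtains that value. The function names are illustrative and are not part of colossalai.

```python
import math


def cross_entropy(logits, target):
    """Cross entropy for one sample: -log(softmax(logits)[target])."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def moe_cross_entropy(logits, target, aux_loss, aux_weight=0.01):
    """Assumed combination: main loss plus the weighted auxiliary term."""
    # aux_loss is a hypothetical argument used only for this illustration.
    return cross_entropy(logits, target) + aux_weight * aux_loss


logits = [2.0, 0.5, -1.0]
main = cross_entropy(logits, 0)
total = moe_cross_entropy(logits, 0, aux_loss=1.5)
print(f"main={main:.4f} total={total:.4f}")
```

With the default aux_weight of 0.01, the auxiliary term shifts the total by only 1% of the auxiliary loss value, so the cross entropy dominates unless the weight is raised.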

class colossalai.nn.loss.loss_moe.MoeLoss(aux_weight, loss_fn, *args, **kwargs)[source]

A wrapper class that adds an auxiliary loss to any loss module.

Parameters
  • aux_weight (float) – Weight of the auxiliary loss in the total loss.

  • loss_fn (Callable) – Loss function.

  • args (list) – Positional arguments for the loss function.

  • kwargs (dict) – Keyword arguments for the loss function.

forward(*args, **kwargs)[source]

The args and kwargs should include at least the parameters below:

input (torch.tensor) – Predicted unnormalized scores (often referred to as logits).
target (torch.tensor) – Ground truth class indices or class probabilities.

Note

The args and kwargs may include different parameters depending on the loss function.
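The wrapper pattern described above can be sketched in plain Python. This is a hypothetical stand-in, not the library class: it assumes loss_fn is a factory (such as a loss class) that is instantiated with the given args and kwargs, and it takes the auxiliary loss through an illustrative aux_loss keyword, because this page does not state where the library gets that value. SketchMoeLoss and SquaredError are made-up names.

```python
class SketchMoeLoss:
    """Hypothetical wrapper that adds a weighted auxiliary loss to any loss."""

    def __init__(self, aux_weight, loss_fn, *args, **kwargs):
        # Assumption: loss_fn is a factory built from the extra args/kwargs.
        self.loss_fn = loss_fn(*args, **kwargs)
        self.aux_weight = aux_weight

    def __call__(self, *args, aux_loss=0.0, **kwargs):
        # Forward everything else to the wrapped loss unchanged.
        main_loss = self.loss_fn(*args, **kwargs)
        return main_loss + self.aux_weight * aux_loss


class SquaredError:
    """Toy loss used only for illustration: scaled mean squared error."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, input, target):
        squared = sum((x - t) ** 2 for x, t in zip(input, target))
        return self.scale * squared / len(input)


criterion = SketchMoeLoss(0.1, SquaredError, scale=0.5)
total = criterion([1.0, 2.0], [0.0, 2.0], aux_loss=2.0)
print(total)
```

Because the wrapper forwards its call arguments untouched, the parameters it accepts are whatever the wrapped loss accepts, which is why they vary from one loss function to another.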