colossalai.nn.loss.loss_1d

class colossalai.nn.loss.loss_1d.VocabParallelCrossEntropyLoss1D(reduction=True)

Vocab-parallel cross-entropy loss for 1D tensor parallelism, where the logits are partitioned across ranks along the vocabulary dimension.

Parameters

reduction (bool, optional) – if True, the per-element losses are averaged into a single scalar; defaults to True.
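The following sketch simulates the vocabulary shards in a single Python process to show how a vocab-parallel cross-entropy can reproduce the ordinary loss. It assumes the Megatron-style scheme, in which the global max, the sum of exponentials, and the target logit are each combined across ranks with an all-reduce. Plain Python max/sum calls stand in for those collectives. The function names are hypothetical, and the code is not ColossalAI's implementation. The `reduction` argument of `batch_loss` mirrors the documented flag.

```python
import math
from typing import List, Union


def sharded_cross_entropy(shards: List[List[float]], target: int) -> float:
    """Cross-entropy for one token whose logits are split across shards.

    Each entry of ``shards`` plays the role of one rank's contiguous slice
    of the vocabulary (hypothetical simulation, no real communication).
    """
    # Step 1: local max per rank, then an all-reduce(MAX) -> global max.
    global_max = max(max(shard) for shard in shards)

    # Step 2: each rank sums exp(z - max) over its slice; all-reduce(SUM).
    sum_exp = sum(math.exp(z - global_max) for shard in shards for z in shard)

    # Step 3: only the rank owning the target index contributes its logit;
    # the others contribute 0, so an all-reduce(SUM) recovers it.
    target_logit = 0.0
    start = 0
    for shard in shards:
        end = start + len(shard)
        if start <= target < end:
            target_logit += shard[target - start] - global_max
        start = end

    # loss = log(sum_j exp(z_j - m)) - (z_target - m)
    return math.log(sum_exp) - target_logit


def dense_cross_entropy(logits: List[float], target: int) -> float:
    """Reference cross-entropy on the full, unsharded logits."""
    m = max(logits)
    return math.log(sum(math.exp(z - m) for z in logits)) - (logits[target] - m)


def batch_loss(batch_shards: List[List[List[float]]], targets: List[int],
               reduction: bool = True) -> Union[float, List[float]]:
    """Per-token losses, averaged when ``reduction`` is True."""
    losses = [sharded_cross_entropy(s, t) for s, t in zip(batch_shards, targets)]
    return sum(losses) / len(losses) if reduction else losses
```

With the shards `[[1.0, 2.0], [0.5, 3.0, -1.0]]`, `sharded_cross_entropy` returns the same value as `dense_cross_entropy` applied to the concatenated list `[1.0, 2.0, 0.5, 3.0, -1.0]` for every target index. No rank ever needs the full vocabulary row. Each rank exchanges only a few scalars per token.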

forward(logits, targets)

Calculates the cross-entropy loss between logits and targets.

Parameters
  • logits (torch.Tensor) – Predicted unnormalized scores (logits); under vocab parallelism, each rank passes its own slice of the vocabulary dimension.

  • targets (torch.Tensor) – Ground-truth class indices or class probabilities.
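Because each rank holds only part of the vocabulary, class-index targets given in global vocabulary coordinates must be translated to local coordinates before a rank can look up its logits. The sketch below shows that step. It assumes the vocabulary is split into equal, contiguous slices, as in Megatron-style vocab parallelism. `vocab_range` and `localize_targets` are hypothetical helpers written for illustration, not ColossalAI functions.

```python
from typing import List, Tuple


def vocab_range(vocab_size: int, rank: int, world_size: int) -> Tuple[int, int]:
    """Return the [start, end) vocabulary slice owned by ``rank``.

    Assumes an even, contiguous split (hypothetical helper).
    """
    if vocab_size % world_size != 0:
        raise ValueError("vocab_size must be divisible by world_size")
    per_rank = vocab_size // world_size
    start = rank * per_rank
    return start, start + per_rank


def localize_targets(targets: List[int], start: int,
                     end: int) -> Tuple[List[int], List[bool]]:
    """Map global class indices to this rank's local indices.

    Targets outside [start, end) are masked: their local index is set to a
    valid placeholder (0), and the mask records that this rank must
    contribute 0 for that token's target logit.
    """
    local, mask = [], []
    for t in targets:
        outside = t < start or t >= end
        mask.append(outside)
        local.append(0 if outside else t - start)
    return local, mask


# Usage: 8-word vocabulary split over 2 ranks, viewed from rank 1.
start, end = vocab_range(vocab_size=8, rank=1, world_size=2)
local, mask = localize_targets([1, 5, 7], start, end)
```

Here `start, end` is `(4, 8)`, `local` is `[0, 1, 3]`, and `mask` is `[True, False, False]`. Target 1 belongs to rank 0, so rank 1 masks it. The placeholder index 0 keeps the lookup in bounds, and the mask zeroes out its contribution.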