colossalai.nn.metric.accuracy_2d

class colossalai.nn.metric.accuracy_2d.Accuracy2D

Accuracy metric for 2D tensor parallelism.
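Under parallel training each device holds only part of the batch, so per-device results have to be combined before an accuracy is reported. The following is a framework-free illustration in plain Python, not ColossalAI's implementation. It assumes the predicted and true labels are already split across several hypothetical partitions (stand-ins for ranks), and `local_correct` and `partitioned_accuracy` are hypothetical helper names. Summing correct counts across partitions, as an all-reduce would, and then dividing by the total sample count reproduces the unpartitioned accuracy.

```python
from typing import List, Sequence, Tuple


def local_correct(preds: Sequence[int], targets: Sequence[int]) -> int:
    """Count correct predictions within one (hypothetical) batch partition."""
    return sum(int(p == t) for p, t in zip(preds, targets))


def partitioned_accuracy(partitions: List[Tuple[Sequence[int], Sequence[int]]]) -> float:
    """Combine per-partition results into one global accuracy.

    Counts are summed first (what an all-reduce of counts would do) and only
    then divided, so partitions of different sizes are weighted correctly.
    """
    total_correct = sum(local_correct(p, t) for p, t in partitions)
    total_samples = sum(len(t) for _, t in partitions)
    return total_correct / total_samples


preds = [0, 1, 1, 2, 2, 0]
targets = [0, 1, 2, 2, 0, 0]
# Split the batch unevenly into two hypothetical partitions.
parts = [(preds[:4], targets[:4]), (preds[4:], targets[4:])]
print(partitioned_accuracy(parts))  # 4 of 6 correct -> 0.666...
```

With this uneven split, 3 of 4 and 1 of 2 samples are correct, giving 4/6 globally, whereas averaging the two per-partition accuracies would give 0.625.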

forward(logits, targets)

Calculate the accuracy of predicted labels.

Parameters
  • logits (torch.Tensor) – Predicted logits (class scores) from which the predicted labels are derived.

  • targets (torch.Tensor) – Ground-truth labels from the data.

Returns

The accuracy of the predictions.

Return type

float
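Because `logits` are class scores rather than labels, a predicted label has to be extracted before it can be compared with `targets`. The sketch below is a single-device illustration in plain Python that ignores the 2D partitioning entirely and is not ColossalAI's code. It assumes the common convention that the predicted label is the index of the largest score in each row; `argmax` and `accuracy` are hypothetical helper names.

```python
from typing import Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score; max() keeps the first index on ties."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the target label."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same batch size")
    if not targets:
        raise ValueError("cannot compute accuracy of an empty batch")
    correct = sum(int(argmax(row) == t) for row, t in zip(logits, targets))
    return correct / len(targets)


logits = [
    [2.0, 0.5, 0.1],  # predicted class 0
    [0.2, 1.5, 0.3],  # predicted class 1
    [0.1, 0.2, 3.0],  # predicted class 2
    [1.0, 0.9, 0.8],  # predicted class 0
]
targets = [0, 1, 1, 0]
print(accuracy(logits, targets))  # 3 of 4 correct -> 0.75
```

Returning a Python `float` mirrors the return type documented above; the explicit length and empty-batch checks make a mismatched or empty input fail loudly instead of silently truncating via `zip` or dividing by zero.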