Source code for pyblaze.nn.functional.metrics

import torch
import sklearn.metrics as metrics

def accuracy(y_pred, y_true):
    """
    Computes the accuracy of the class predictions.

    Parameters
    ----------
    y_pred: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The class predictions made by the model. Can be either specific classes or predictions
        for each class.
    y_true: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The actual classes, either given as indices or one-hot vectors (more specifically, it may
        be any vector whose row-wise argmax values yield the class labels).

    Returns
    -------
    torch.FloatTensor
        The accuracy, i.e. the fraction of samples whose predicted class equals the actual class.
    """
    # Both sides are normalized to plain class labels before being compared element-wise.
    predicted, actual = _ensure_classes(y_pred), _ensure_classes(y_true)
    hits = predicted.eq(actual)
    return hits.float().mean()
def recall(y_pred, y_true, c=1):
    """
    Computes the recall score of the class predictions.

    Recall is the fraction of samples that actually belong to class `c` which the model also
    assigned to class `c`, i.e. TP / (TP + FN).

    Parameters
    ----------
    y_pred: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The class predictions made by the model. Can be either specific classes or predictions
        for each class.
    y_true: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The actual classes, either given as indices or one-hot vectors (more specifically, it may
        be any vector whose row-wise argmax values yield the class labels).
    c: int, default: 1
        The class to calculate the recall score for. Default assumes a binary classification
        setting.

    Returns
    -------
    torch.FloatTensor
        The recall score. NaN if no sample actually belongs to class `c` (0 / 0).
    """
    y_pred = _ensure_classes(y_pred) == c
    # Normalize the targets as well so one-hot targets work, consistent with `accuracy`.
    y_true = _ensure_classes(y_true) == c
    true_positives = (y_pred & y_true).sum()
    actual_positives = y_true.sum()
    return true_positives.float() / actual_positives.float()
def precision(y_pred, y_true, c=1):
    """
    Computes the precision score of the class predictions.

    Precision is the fraction of samples the model assigned to class `c` which actually belong to
    class `c`, i.e. TP / (TP + FP).

    Parameters
    ----------
    y_pred: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The class predictions made by the model. Can be either specific classes or predictions
        for each class.
    y_true: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The actual classes, either given as indices or one-hot vectors (more specifically, it may
        be any vector whose row-wise argmax values yield the class labels).
    c: int, default: 1
        The class to calculate the precision score for. Default assumes a binary classification
        setting.

    Returns
    -------
    torch.FloatTensor
        The precision score. NaN if the model assigns no sample to class `c` (0 / 0).
    """
    y_pred = _ensure_classes(y_pred) == c
    # Normalize the targets as well so one-hot targets work, consistent with `accuracy`.
    y_true = _ensure_classes(y_true) == c
    true_positives = (y_pred & y_true).sum()
    predicted_positives = y_pred.sum()
    return true_positives.float() / predicted_positives.float()
def f1_score(y_pred, y_true, c=1):
    """
    Computes the F1-score of the class predictions.

    The F1-score is the harmonic mean of precision and recall for class `c`.

    Parameters
    ----------
    y_pred: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The class predictions made by the model. Can be either specific classes or predictions
        for each class.
    y_true: torch.LongTensor [N] or torch.FloatTensor [N, C]
        The actual classes, either given as indices or one-hot vectors (more specifically, it may
        be any vector whose row-wise argmax values yield the class labels).
    c: int, default: 1
        The class to calculate the F1-score for. Default assumes a binary classification setting.

    Returns
    -------
    torch.FloatTensor
        The F1-score. It is 0 if precision and recall are both 0. It is NaN if either of them is
        undefined (NaN).
    """
    # Convert to class labels once so precision and recall both see plain label tensors.
    y_pred = _ensure_classes(y_pred)
    y_true = _ensure_classes(y_true)
    p = precision(y_pred, y_true, c=c)
    r = recall(y_pred, y_true, c=c)
    denominator = p + r
    if denominator == 0:
        # Both scores are well-defined but zero (no true positives). By convention F1 is then 0,
        # whereas the plain formula would yield 0 / 0 = NaN.
        return torch.zeros_like(denominator)
    return (2 * p * r) / denominator
def roc_auc_score(y_pred, y_true):
    """
    Computes the area under the ROC curve.

    Parameters
    ----------
    y_pred: torch.FloatTensor [N]
        The (binary) predictions made by the model.
    y_true: torch.LongTensor [N]
        The actual classes.

    Returns
    -------
    torch.FloatTensor
        The ROC-AUC score.
    """
    assert y_pred.dim() == 1, \
        "ROC-AUC score only works in the binary case."
    # scikit-learn needs host-side NumPy arrays. `.numpy()` raises for tensors that require grad
    # or live on a GPU, so detach and move to the CPU first (both are no-ops otherwise).
    y_true_np = y_true.detach().cpu().numpy()
    y_pred_np = y_pred.detach().cpu().numpy()
    return torch.as_tensor(
        metrics.roc_auc_score(y_true_np, y_pred_np)
    )
def pr_auc_score(y_pred, y_true):
    """
    Computes the area under the precision-recall curve.

    Parameters
    ----------
    y_pred: torch.FloatTensor [N]
        The (binary) predictions made by the model.
    y_true: torch.LongTensor [N]
        The actual classes.

    Returns
    -------
    torch.FloatTensor
        The PR-AUC score, obtained by applying `metrics.auc` to the (recall, precision) points
        of the precision-recall curve.
    """
    # scikit-learn needs host-side NumPy arrays. `.numpy()` raises for tensors that require grad
    # or live on a GPU, so detach and move to the CPU first (both are no-ops otherwise).
    y_true_np = y_true.detach().cpu().numpy()
    y_pred_np = y_pred.detach().cpu().numpy()
    prec, rec, _ = metrics.precision_recall_curve(y_true_np, y_pred_np)
    auc = metrics.auc(rec, prec)
    return torch.as_tensor(auc)
def average_precision(y_pred, y_true):
    """
    Computes the average precision of the model predictions.

    Parameters
    ----------
    y_pred: torch.FloatTensor [N]
        The (binary) predictions made by the model.
    y_true: torch.LongTensor [N]
        The actual classes.

    Returns
    -------
    torch.FloatTensor
        The average precision.
    """
    assert y_pred.dim() == 1, \
        "Average precision only work in the binary case."
    # scikit-learn needs host-side NumPy arrays. `.numpy()` raises for tensors that require grad
    # or live on a GPU, so detach and move to the CPU first (both are no-ops otherwise).
    y_true_np = y_true.detach().cpu().numpy()
    y_pred_np = y_pred.detach().cpu().numpy()
    return torch.as_tensor(
        metrics.average_precision_score(y_true_np, y_pred_np)
    )
def _ensure_classes(y):
    """
    Converts model outputs or targets into class labels.

    - 2-D tensors [N, C] are treated as per-class scores (or one-hot vectors) and reduced to the
      index of the largest entry in each row.
    - Floating point tensors of any other rank are rounded with `torch.round`, so that binary
      probabilities in [0, 1] become the labels 0. and 1.
    - Any other tensor (e.g. a LongTensor of class indices) is returned unchanged.
    """
    if y.dim() == 2:
        return torch.argmax(y, dim=-1)
    # This check used to come after an unconditional early return for dim() < 2. As a result,
    # 1-D float predictions were never rounded and then compared unequal to integer labels.
    # `is_floating_point` also covers float64/float16, not just float32.
    if y.is_floating_point():
        return torch.round(y)
    return y