""" Contains two class classification metrics """
from copy import copy
from functools import partial
import numpy as np
from ...decorators import mjit
from . import Metrics, binarize, sigmoid, infmean
METRICS_ALIASES = {'sensitivity' : 'true_positive_rate',
'recall' : 'true_positive_rate',
'tpr' : 'true_positive_rate',
'fallout' : 'false_positive_rate',
'fpr' : 'false_positive_rate',
'miss_rate' : 'false_negative_rate',
'fnr' : 'false_negative_rate',
'specificity' : 'true_negative_rate',
'tnr' : 'true_negative_rate',
'prv' : 'prevalence',
'acc' : 'accuracy',
'precision' : 'positive_predictive_value',
'ppv' : 'positive_predictive_value',
'fdr' : 'false_discovery_rate',
'for' : 'false_omission_rate',
'npv' : 'negative_predictive_value',
'plr' : 'positive_likelihood_ratio',
'nlr' : 'negative_likelihood_ratio',
'dor' : 'diagnostics_odds_ratio',
'dice' : 'f1_score',
'f1s' : 'f1_score',
'iou' : 'jaccard',
'jac' : 'jaccard'}
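# Aliases are resolved lazily in `__getattr__` below, so e.g. `metrics.recall(...)`
# and `metrics.tpr(...)` both dispatch to `true_positive_rate(...)`.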
class ClassificationMetrics(Metrics):
""" Metrics to assess classification models
Parameters
----------
targets : np.array
Ground-truth labels / probabilities / logits
predictions : np.array
Predicted labels / probabilites / logits
num_classes : int
the number of classes (default is None)
fmt : 'proba', 'logits', 'labels'
whether arrays contain probabilities, logits or labels
axis : int
a class axis (default is None)
threshold : float
A probability level for binarization (lower values become 0, equal or greater values become 1)

    Notes
    -----
    - Input arrays (`targets` and `predictions`) might be vectors or multidimensional arrays,
      where the first dimension represents batch items. The latter is useful for pixel-level metrics.
    - Both `targets` and `predictions` usually contain the same type of data (labels, probabilities or logits).
      However, `targets` might be labels while `predictions` are probabilities / logits.
      For that to work:

      - `targets` should have a shape with exactly one dimension fewer than the shape of `predictions`;
      - `axis` should point to that dimension;
      - `fmt` should describe the format of `predictions`.

    - When `axis` is specified, `predictions` should be a one-hot array with class information provided
      in the given axis (class probabilities or logits). In this case `targets` can contain labels (see above)
      or probabilities / logits in the very same axis.
    - If `fmt` is 'labels', `num_classes` should be specified: due to randomness, a given batch may not
      contain items of every class, so the number of classes cannot be inferred simply as `labels.max()`.
    - If `fmt` is 'proba' or 'logits', then `axis` points to the one-hot dimension.
      However, if `axis` is None, two-class classification is assumed and `targets` / `predictions`
      should contain probabilities or logits for the positive class only.

    **Metrics**

    All metrics return:

    - a single value if the input is a vector for a 2-class task;
    - a single value if the input is a vector for a multiclass task and multiclass averaging is enabled;
    - a vector with batch-size items if the input is a multidimensional array (e.g. images or sequences)
      and there are just 2 classes or multiclass averaging is on;
    - a vector with `num_classes` items if the input is a vector for a multiclass case without averaging;
    - a 2d array `(batch_items, num_classes)` for multidimensional inputs in a multiclass case without averaging.

    .. note:: Count-based metrics (`true_positive`, `false_positive`, etc.) do not support multiclass averaging.
              They always return counts for each class separately.
              For multiclass tasks, rate metrics such as `true_positive_rate` and `false_positive_rate`
              might be more convenient.

    **Multiclass metrics**

    In a multiclass case, metrics might be calculated with or without class averaging.
    Available methods are:

    - `None` - no averaging, calculate metrics for each class individually (one-vs-all)
    - `'micro'` - calculate metrics globally by counting the total true positives,
      false negatives, false positives, etc. across all classes
    - `'macro'` - calculate metrics for each class and take their mean.

    Examples
    --------
    ::

metrics = ClassificationMetrics(targets, predictions, num_classes=10, fmt='labels')
metrics.evaluate(['sensitivity', 'specificity'], multiclass='macro')
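
    Pixel-level metrics when `targets` are labels and `predictions` are class probabilities
    (a minimal sketch; the shapes below are purely illustrative) ::

        # targets: (batch, height, width) labels; predictions: (batch, 10, height, width) probabilities
        metrics = ClassificationMetrics(targets, predictions, num_classes=10, fmt='proba', axis=1)
        metrics.evaluate('jaccard', multiclass=None)    # a (batch, num_classes) array of per-class IoU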
"""
def __init__(self, targets, predictions, fmt='proba', num_classes=None, axis=None, threshold=.5,
skip_bg=False, calc=True):
super().__init__()
self.targets = None
self.predictions = None
self._confusion_matrix = None
self.skip_bg = skip_bg
self.num_classes = None if axis is None else predictions.shape[axis]
self.num_classes = self.num_classes or num_classes or 2
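        # Infer the number of classes: from the class axis of `predictions` if `axis` is given,
        # otherwise from the `num_classes` argument, falling back to a two-class setup.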
self._agg_fn_dict = {'mean': partial(infmean, axis=0)}
if fmt in ['proba', 'logits'] and axis is None and self.num_classes > 2:
raise ValueError('axis cannot be None for multiclass case when fmt is proba or logits')
if targets.ndim == predictions.ndim:
# targets and predictions contain the same info (labels, probabilities or logits)
targets = self._to_labels(targets, fmt, axis, threshold)
elif targets.ndim == predictions.ndim - 1 and fmt != 'labels':
# targets contains labels while predictions is a one-hot array
pass
else:
raise ValueError("targets and predictions should have compatible shapes",
targets.shape, predictions.shape)
predictions = self._to_labels(predictions, fmt, axis, threshold)
if targets.ndim == 1:
targets = targets.reshape(1, -1)
predictions = predictions.reshape(1, -1)
self._no_zero_axis = True
else:
self._no_zero_axis = False
self.targets = targets
self.predictions = predictions
if calc:
self._calc()
def __getattr__(self, name):
if name == "METRICS_ALIASES":
raise AttributeError # See https://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html
name = METRICS_ALIASES.get(name, name)
return object.__getattribute__(self, name)
@property
    def confusion_matrix(self):
        """ Total confusion matrix, summed over all batch items """
        return self._confusion_matrix.sum(axis=0)
    def plot_confusion_matrix(self, classes=None, normalize=False, **kwargs):
        """ Plot the confusion matrix.

        Parameters
        ----------
        classes : sequence, optional
            Sequence of class labels.
        normalize : bool
            Whether to normalize the confusion matrix over target classes.
        """
confusion_matrix = np.array(self.confusion_matrix)
if classes is None:
classes = np.arange(self.num_classes)
if normalize:
confusion_matrix = confusion_matrix / np.nansum(confusion_matrix, axis=0)
plot_config = {
'title': 'Normalized confusion matrix' if normalize else 'Confusion matrix',
'xlabel': 'Actual class',
'xtick_locations': np.arange(confusion_matrix.shape[0]),
'xtick_labels': classes,
'xtick_rotation': 90,
'xtick_ha': 'center',
'ylabel': 'Predicted class',
'ytick_locations': np.arange(confusion_matrix.shape[1]),
'ytick_labels': classes,
'ytick_rotation': 0,
'ytick_va': 'center',
**kwargs
}
from ...plotter import plot
return plot(data=confusion_matrix, mode='matrix', **plot_config)
    def copy(self):
""" Return a duplicate containing only the confusion matrix """
metrics = copy(self)
metrics.free()
return metrics
def _to_labels(self, arr, fmt, axis, threshold):
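        # Labels pass through unchanged; scalar positive-class scores are binarized by `threshold`
        # (logits go through a sigmoid first), while one-hot arrays are reduced with argmax over `axis`.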
if fmt == 'labels':
pass
elif fmt in ['proba', 'logits']:
if axis is None:
if fmt == 'logits':
arr = sigmoid(arr)
arr = binarize(arr, threshold).astype('int8')
else:
arr = arr.argmax(axis=axis)
return arr
    def one_hot(self, inputs):
""" Convert an array of labels into a one-hot array """
return np.eye(self.num_classes)[inputs] if self.num_classes > 2 else inputs
    def free(self):
""" Free memory allocated for intermediate data """
self.targets = None
self.predictions = None
    def append(self, metrics):
        """ Extend the confusion matrix with data from another metrics instance """
# pylint: disable=protected-access
self._confusion_matrix = np.concatenate((self._confusion_matrix, metrics._confusion_matrix), axis=0)
    def update(self, metrics):
        """ Update the confusion matrix with data from another metrics instance """
# pylint: disable=protected-access
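        # For flat (vector) inputs the whole batch collapses into a single confusion matrix,
        # so matrices are summed; otherwise new batch items are stacked along axis 0.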
if self._no_zero_axis:
self._confusion_matrix = self._confusion_matrix + metrics._confusion_matrix
else:
self._confusion_matrix = np.concatenate((self._confusion_matrix, metrics._confusion_matrix), axis=0)
def __getitem__(self, item):
# pylint: disable=protected-access
metrics = self.copy()
metrics._confusion_matrix = metrics._confusion_matrix[item]
return metrics
def _calc(self):
self._confusion_matrix = np.zeros((self.targets.shape[0], self.num_classes, self.num_classes), dtype=np.intp)
return self._calc_confusion_jit(self.targets, self.predictions, self.num_classes, self._confusion_matrix)
@mjit
def _calc_confusion_jit(self, targets, predictions, num_classes, confusion):
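        # Fill per-item confusion matrices in place; `confusion[i, p, t]` counts items
        # of true class `t` that were predicted as class `p`.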
for i in range(targets.shape[0]):
targ = targets[i].flatten()
pred = predictions[i].flatten()
for t in range(num_classes):
coords = np.where(targ == t)
for c in pred[coords]:
confusion[i, c, t] += 1
def _return(self, value):
return value[0] if isinstance(value, np.ndarray) and value.shape == (1, ) else value
def _all_labels(self):
first = 1 if self.skip_bg else 0
labels = list(range(first, self.num_classes))
return labels
def _count(self, f, label=None):
if label is None:
label = self._all_labels() if self.num_classes > 2 else 1
if np.isscalar(label):
return self._return(f(label))
return np.array([self._return(f(l)) for l in label]).T
    def true_positive(self, label=None, *args, **kwargs):
        """ Count items of the given class that are predicted as that class """
        _ = args, kwargs
        return self._count(lambda l: self._confusion_matrix[:, l, l], label)

    def false_positive(self, label=None, *args, **kwargs):
        """ Count items of other classes that are predicted as the given class """
        _ = args, kwargs
        return self._count(lambda l: self.prediction_positive(l) - self.true_positive(l), label)

    def true_negative(self, label=None, *args, **kwargs):
        """ Count items of other classes that are not predicted as the given class """
        _ = args, kwargs
        return self._count(lambda l: self.condition_negative(l) - self.false_positive(l), label)

    def false_negative(self, label=None, *args, **kwargs):
        """ Count items of the given class that are predicted as some other class """
        _ = args, kwargs
        return self._count(lambda l: self.condition_positive(l) - self.true_positive(l), label)
    def condition_positive(self, label=None, *args, **kwargs):
        """ Count items that truly belong to the given class """
        _ = args, kwargs
        return self._count(lambda l: self._confusion_matrix[:, :, l].sum(axis=1), label)

    def condition_negative(self, label=None, *args, **kwargs):
        """ Count items that truly belong to other classes """
        _ = args, kwargs
        return self._count(lambda l: self.total_population(l) - self.condition_positive(l), label)

    def prediction_positive(self, label=None, *args, **kwargs):
        """ Count items predicted as the given class """
        _ = args, kwargs
        return self._count(lambda l: self._confusion_matrix[:, l].sum(axis=1), label)

    def prediction_negative(self, label=None, *args, **kwargs):
        """ Count items predicted as other classes """
        _ = args, kwargs
        return self._count(lambda l: self.total_population(l) - self.prediction_positive(l), label)
    def total_population(self, *args, **kwargs):
        """ Count all items in the batch """
        _ = args, kwargs
        return self._return(self._confusion_matrix.sum(axis=(1, 2)))
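    # `_calc_agg` turns a numerator / denominator pair of count metrics into a rate and aggregates
    # over classes. For instance, with per-class fractions 3/4 and 1/2, 'micro' pools the counts,
    # (3 + 1) / (4 + 2) = 2/3, while 'macro' averages the ready rates, (0.75 + 0.5) / 2 = 0.625.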
def _calc_agg(self, numer, denom, label=None, multiclass='macro', when_zero=None):
_when_zero = lambda n: np.where(n > 0, when_zero[0], when_zero[1]).astype(float)
if self.num_classes == 2:
label = label if label is not None else 1
labels = label if label is not None else self._all_labels()
labels = labels if isinstance(labels, (list, tuple)) else [labels]
fractions = [(numer(l).astype(float), denom(l).astype(float)) for l in labels]
if multiclass is None:
value = [np.divide(n, d, out=_when_zero(n), where=(d > 0)).ravel() for n, d in fractions]
classes_calculated = self.num_classes - 1 if self.skip_bg else self.num_classes
value = value[0] if len(value) == 1 else np.array(value).T.reshape(-1, classes_calculated)
elif multiclass == 'micro':
n = np.sum([f[0] for f in fractions], axis=0)
d = np.sum([f[1] for f in fractions], axis=0)
value = np.divide(n, d, out=_when_zero(n), where=(d > 0)).reshape(-1, 1)
elif multiclass in ['macro', 'mean']:
value = [np.divide(n, d, out=_when_zero(n), where=(d > 0)) for n, d in fractions]
value = infmean(value, axis=0).reshape(-1, 1)
return value
    def true_positive_rate(self, *args, when_zero=(0, 1), **kwargs):
        """ true_positive / condition_positive, a.k.a. sensitivity or recall """
        return self._calc_agg(self.true_positive, self.condition_positive, *args, when_zero=when_zero, **kwargs)

    def false_positive_rate(self, *args, when_zero=(1, 0), **kwargs):
        """ false_positive / condition_negative, a.k.a. fallout """
        return self._calc_agg(self.false_positive, self.condition_negative, *args, when_zero=when_zero, **kwargs)

    def false_negative_rate(self, *args, when_zero=(1, 0), **kwargs):
        """ false_negative / condition_positive, a.k.a. miss rate """
        return self._calc_agg(self.false_negative, self.condition_positive, *args, when_zero=when_zero, **kwargs)

    def true_negative_rate(self, *args, when_zero=(0, 1), **kwargs):
        """ true_negative / condition_negative, a.k.a. specificity """
        return self._calc_agg(self.true_negative, self.condition_negative, *args, when_zero=when_zero, **kwargs)
    def prevalence(self, *args, when_zero=(0, 0), **kwargs):
        """ condition_positive / total_population: the fraction of items of the given class.

        Notes
        -----
        The `when_zero` parameter hardly matters here, since `total_population`
        is never zero when `targets` is non-empty.
        """
        return self._calc_agg(self.condition_positive, self.total_population, *args, when_zero=when_zero, **kwargs)
    def accuracy(self):
        """ The accuracy of detecting all classes combined """
return np.sum([self.true_positive(l) for l in self._all_labels()], axis=0) / self.total_population()
    def positive_predictive_value(self, *args, when_zero=(0, 1), **kwargs):
        """ true_positive / prediction_positive, a.k.a. precision """
        return self._calc_agg(self.true_positive, self.prediction_positive, *args, when_zero=when_zero, **kwargs)

    def false_discovery_rate(self, *args, when_zero=(1, 0), **kwargs):
        """ false_positive / prediction_positive """
        return self._calc_agg(self.false_positive, self.prediction_positive, *args, when_zero=when_zero, **kwargs)

    def false_omission_rate(self, *args, when_zero=(1, 0), **kwargs):
        """ false_negative / prediction_negative """
        return self._calc_agg(self.false_negative, self.prediction_negative, *args, when_zero=when_zero, **kwargs)

    def negative_predictive_value(self, *args, when_zero=(0, 1), **kwargs):
        """ true_negative / prediction_negative """
        return self._calc_agg(self.true_negative, self.prediction_negative, *args, when_zero=when_zero, **kwargs)
    def positive_likelihood_ratio(self, *args, when_zero=(np.inf, 0), **kwargs):
        """ true_positive_rate / false_positive_rate """
        return self._calc_agg(self.true_positive_rate, self.false_positive_rate, *args, when_zero=when_zero, **kwargs)

    def negative_likelihood_ratio(self, *args, when_zero=(np.inf, 0), **kwargs):
        """ false_negative_rate / true_negative_rate """
        return self._calc_agg(self.false_negative_rate, self.true_negative_rate, *args, when_zero=when_zero, **kwargs)

    def diagnostics_odds_ratio(self, *args, when_zero=(np.inf, 0), **kwargs):
        """ positive_likelihood_ratio / negative_likelihood_ratio, a.k.a. the diagnostic odds ratio """
        return self._calc_agg(self.positive_likelihood_ratio, self.negative_likelihood_ratio,
                              *args, when_zero=when_zero, **kwargs)
    def f1_score(self, *args, **kwargs):
        """ Compute the F1 score: the harmonic mean of precision and recall """
recall = self.recall(*args, when_zero=(0, np.inf), **kwargs)
precision = self.precision(*args, when_zero=(0, np.inf), **kwargs)
mask = np.isinf(recall) & np.isinf(precision)
value = np.nan_to_num(2 * (recall * precision) / (recall + precision))
value[mask] = np.inf
return value
    def jaccard(self, *args, **kwargs):
        """ Compute the Jaccard index (IoU) from the Dice coefficient: J = D / (2 - D) """
        d = self.dice(*args, **kwargs)
        return np.nan_to_num(d / (2 - d), nan=np.inf)