Source code for probflow.utils.metrics

"""Metrics.

Evaluation metrics

* :func:`.log_prob`
* :func:`.acc`
* :func:`.accuracy`
* :func:`.mse`
* :func:`.sse`
* :func:`.mae`

----------

"""


__all__ = [
    "accuracy",
    "mean_squared_error",
    "sum_squared_error",
    "mean_absolute_error",
    "r_squared",
    "true_positive_rate",
    "true_negative_rate",
    "precision",
    "f1_score",
    "get_metric_fn",
]


import numpy as np
import pandas as pd


def as_numpy(fn):
    """Cast inputs to numpy arrays and same shape before computing metric"""

    def metric_fn(y_true, y_pred):

        # Cast y_true to numpy array
        if isinstance(y_true, (pd.Series, pd.DataFrame)):
            new_y_true = y_true.values
        elif isinstance(y_true, np.ndarray):
            new_y_true = y_true
        else:
            new_y_true = y_true.numpy()

        # Cast y_pred to numpy array
        if isinstance(y_pred, (pd.Series, pd.DataFrame)):
            new_y_pred = y_pred.values
        elif isinstance(y_pred, np.ndarray):
            new_y_pred = y_pred
        else:
            new_y_pred = y_pred.numpy()

        # Ensure both arrays are at least 2-dimensional
        if new_y_true.ndim == 1:
            new_y_true = np.expand_dims(new_y_true, 1)
        if new_y_pred.ndim == 1:
            new_y_pred = np.expand_dims(new_y_pred, 1)

        # Return metric function on consistent arrays
        return fn(new_y_true, new_y_pred)

    return metric_fn
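
# A hedged usage sketch (not part of the original module): because of
# as_numpy, each decorated metric below accepts pandas Series/DataFrames,
# numpy arrays, or any object with a .numpy() method (e.g. framework
# tensors) interchangeably:
#
#   mean_absolute_error(pd.Series([1.0, 2.0]), np.array([1.0, 4.0]))  # -> 1.0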


@as_numpy
def accuracy(y_true, y_pred):
    """Accuracy of predictions."""
    return np.mean(y_pred == y_true)
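
# Example (illustrative, not from the source): accuracy counts exact matches,
# so for three predictions where two agree with the labels:
#   accuracy(np.array([1, 0, 1]), np.array([1, 1, 1]))  # -> 2/3 ≈ 0.667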


@as_numpy
def mean_squared_error(y_true, y_pred):
    """Mean squared error."""
    return np.mean(np.square(y_true - y_pred))


@as_numpy
def sum_squared_error(y_true, y_pred):
    """Sum of squared error."""
    return np.sum(np.square(y_true - y_pred))


@as_numpy
def mean_absolute_error(y_true, y_pred):
    """Mean absolute error."""
    return np.mean(np.abs(y_true - y_pred))
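
# Worked example (illustrative, not from the source): for
# y_true = np.array([1, 2, 3]) and y_pred = np.array([1, 2, 5]),
# the errors are [0, 0, 2], so:
#   sum_squared_error(y_true, y_pred)    # -> 4.0
#   mean_squared_error(y_true, y_pred)   # -> 4/3 ≈ 1.333
#   mean_absolute_error(y_true, y_pred)  # -> 2/3 ≈ 0.667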


@as_numpy
def r_squared(y_true, y_pred):
    """Coefficient of determination."""
    ss_tot = np.sum(np.square(y_true - np.mean(y_true)))
    ss_res = np.sum(np.square(y_true - y_pred))
    return 1.0 - ss_res / ss_tot
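
# Worked example (illustrative): continuing with y_true = np.array([1, 2, 3])
# and y_pred = np.array([1, 2, 5]): mean(y_true) = 2, so ss_tot = 2 and
# ss_res = 4, giving r_squared = 1 - 4/2 = -1.0.  Negative values mean the
# predictions fit worse than simply predicting the mean of y_true.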


@as_numpy
def true_positive_rate(y_true, y_pred):
    """True positive rate aka sensitivity aka recall."""
    p = np.sum(y_true == 1)
    tp = np.sum((y_pred == y_true) & (y_true == 1))
    return tp / p


@as_numpy
def true_negative_rate(y_true, y_pred):
    """True negative rate aka specificity aka selectivity."""
    n = np.sum(y_true == 0)
    tn = np.sum((y_pred == y_true) & (y_true == 0))
    return tn / n
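
# Example (illustrative): with y_true = np.array([1, 1, 0, 0]) and
# y_pred = np.array([1, 0, 0, 0]), 1 of the 2 positives is recovered
# (true_positive_rate = 0.5) and both negatives are (true_negative_rate = 1.0).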


@as_numpy
def precision(y_true, y_pred):
    """Precision."""
    ap = np.sum(y_pred)
    tp = np.sum((y_pred == y_true) & (y_true == 1))
    return tp / ap


@as_numpy
def f1_score(y_true, y_pred):
    """F-measure."""
    p = precision(y_true, y_pred)
    r = true_positive_rate(y_true, y_pred)
    return 2 * (p * r) / (p + r)
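
# Example (illustrative): with y_true = np.array([1, 1, 0, 0]) and
# y_pred = np.array([1, 0, 1, 0]), precision = 1/2 and recall = 1/2, so
# f1_score = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5.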


# TODO: jaccard_similarity


# TODO: roc_auc


# TODO: cross_entropy


def get_metric_fn(metric):
    """Get a function corresponding to a metric string"""

    # List of valid metric strings
    metrics = {
        "accuracy": accuracy,
        "acc": accuracy,
        "mean_squared_error": mean_squared_error,
        "mse": mean_squared_error,
        "sum_squared_error": sum_squared_error,
        "sse": sum_squared_error,
        "mean_absolute_error": mean_absolute_error,
        "mae": mean_absolute_error,
        "r_squared": r_squared,
        "r2": r_squared,
        "recall": true_positive_rate,
        "sensitivity": true_positive_rate,
        "true_positive_rate": true_positive_rate,
        "tpr": true_positive_rate,
        "specificity": true_negative_rate,
        "selectivity": true_negative_rate,
        "true_negative_rate": true_negative_rate,
        "tnr": true_negative_rate,
        "precision": precision,
        "f1_score": f1_score,
        "f1": f1_score,
        # 'jaccard_similarity': jaccard_similarity,
        # 'jaccard': jaccard_similarity,
        # 'roc_auc': roc_auc,
        # 'auroc': roc_auc,
        # 'auc': roc_auc,
    }

    # Return the corresponding function
    if callable(metric):
        return metric
    elif isinstance(metric, str):
        if metric not in metrics:
            raise ValueError(
                metric
                + " is not a valid metric string. "
                + "Valid strings are: "
                + ", ".join(metrics.keys())
            )
        else:
            return metrics[metric]
    else:
        raise TypeError("metric must be a str or callable")
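

# Usage sketch (assumed, based on the lookup table above): metric strings and
# their aliases resolve to the same function, and callables pass through:
#
#   get_metric_fn("mse") is mean_squared_error  # True
#   get_metric_fn("r2") is r_squared            # True
#   get_metric_fn(mean_squared_error)           # returned unchanged
#   get_metric_fn("nope")                       # raises ValueError
#   get_metric_fn(3.14)                         # raises TypeError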