
flexcv.metrics

This module implements the MetricsDict class and its default metrics. It specifies which metrics are calculated in the outer loop of the cross-validation.

flexcv.metrics.MetricsDict

Bases: dict

A dictionary that maps metric names to functions. It can be passed to the cross_validate function to specify which metrics to calculate in the outer loop.

Default Metrics

By default, the following metrics are initialized:

  • R²: The coefficient of determination

  • MSE: Mean squared error

  • MAE: Mean absolute error

We decided against using the RMSE as a default metric, because averaging it across folds is error-prone. RMSE should always be averaged as sqrt(mean(MSE_values)) and not as mean(sqrt(MSE_values)). Likewise, its standard deviation across folds would be calculated incorrectly if RMSE were included at this point.
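The difference between the two averaging orders can be shown with a few made-up per-fold MSE values. This is a minimal sketch assuming equally sized folds; the numbers are illustrative only and are not produced by flexcv.

```python
import math
import statistics

# Illustrative per-fold MSE values from a hypothetical 3-fold outer loop.
mse_values = [1.0, 4.0, 16.0]

# Pooled RMSE: average the MSE values first, then take the square root.
rmse_pooled = math.sqrt(statistics.mean(mse_values))

# Naive RMSE: take the square root per fold, then average.
rmse_naive = statistics.mean([math.sqrt(m) for m in mse_values])

print(f"sqrt(mean(MSE)) = {rmse_pooled:.4f}")  # sqrt(mean(MSE)) = 2.6458
print(f"mean(sqrt(MSE)) = {rmse_naive:.4f}")   # mean(sqrt(MSE)) = 2.3333
```

Because the square root is concave, the naive average never exceeds the pooled value (Jensen's inequality), so averaging per-fold RMSE systematically understates the error whenever the folds differ.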

Parameters:

None. The constructor takes no arguments; the instance itself is a dict that maps metric names to functions.
Example
from flexcv.metrics import MetricsDict

def naive_metric(valid, pred):
    return 42

# instantiate a MetricsDict with the default metrics R², MSE and MAE
metrics = MetricsDict()
# add a custom metric
metrics["naive_metric"] = naive_metric
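To illustrate how a mapping like this can be consumed, the sketch below applies every metric to each outer fold and then aggregates the per-fold values into a mean and standard deviation. It assumes each metric is called as fn(y_valid, y_pred), matching the naive_metric signature in the example. The helpers evaluate_fold and summarize are hypothetical and not part of flexcv, and the pure-Python mse and mae stand in for sklearn's mean_squared_error and mean_absolute_error so the example runs without dependencies.

```python
import statistics
from typing import Callable, Dict, List, Sequence, Tuple

Metric = Callable[[Sequence[float], Sequence[float]], float]


# Pure-Python stand-ins for sklearn's mean_squared_error / mean_absolute_error.
def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return statistics.mean([(t - p) ** 2 for t, p in zip(y_true, y_pred)])


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return statistics.mean([abs(t - p) for t, p in zip(y_true, y_pred)])


def naive_metric(valid, pred):
    return 42


def evaluate_fold(metrics: Dict[str, Metric], y_valid, y_pred) -> Dict[str, float]:
    """Hypothetical helper: apply every metric in the mapping to one outer fold."""
    return {name: fn(y_valid, y_pred) for name, fn in metrics.items()}


def summarize(fold_results: List[Dict[str, float]]) -> Dict[str, Tuple[float, float]]:
    """Hypothetical helper: mean and standard deviation of each metric across folds."""
    summary = {}
    for name in fold_results[0]:
        values = [result[name] for result in fold_results]
        summary[name] = (statistics.mean(values), statistics.stdev(values))
    return summary


metrics = {"mse": mse, "mae": mae, "naive_metric": naive_metric}

# Two illustrative outer folds, each given as (y_valid, y_pred).
folds = [
    ([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]),
    ([0.0, 0.0], [2.0, -2.0]),
]
fold_results = [evaluate_fold(metrics, y_valid, y_pred) for y_valid, y_pred in folds]
summary = summarize(fold_results)
print(summary)
```

Keeping the metrics in a plain name-to-function mapping means a custom metric such as naive_metric is evaluated and aggregated exactly like the defaults, with no extra registration step in this sketch.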
Source code in flexcv/metrics.py
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


class MetricsDict(dict):
    """A dictionary that maps metric names to functions.
    It can be passed to the cross_validate function to specify which metrics to calculate in the outer loop.

    Default Metrics:
        By default, the following metrics are initialized:

        - R²: The coefficient of determination

        - MSE: Mean squared error

        - MAE: Mean absolute error

        We decided against using the RMSE as a default metric, because averaging it across folds is error-prone.
        RMSE should always be averaged as `sqrt(mean(MSE_values))` and not as `mean(sqrt(MSE_values))`.
        Likewise, its standard deviation across folds would be calculated incorrectly if RMSE were included at this point.

    Example:
        ```python
        from flexcv.metrics import MetricsDict

        def naive_metric(valid, pred):
            return 42

        # instantiate a MetricsDict with the default metrics R², MSE and MAE
        metrics = MetricsDict()
        # add a custom metric
        metrics["naive_metric"] = naive_metric
        ```
    """

    def __init__(self):
        super().__init__()
        self["r2"] = r2_score
        self["mse"] = mean_squared_error
        self["mae"] = mean_absolute_error

flexcv.metrics.mse_wrapper(y_valid, y_pred, y_train_in, y_pred_train)

This function is used only to calculate the objective function value for the hyperparameter optimization. To allow for customized objective functions, it takes the validation and training targets together with the corresponding predictions as arguments; comparing training and validation performance can be useful to avoid overfitting. The sklearn MSE function therefore had to be wrapped accordingly. This default wrapper ignores the training arguments and returns the MSE on the validation set only.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| y_valid | array-like | Target in the validation set. | required |
| y_pred | array-like | Predictions for the validation set. | required |
| y_train_in | array-like | Target in the training set. | required |
| y_pred_train | array-like | Predictions for the training set. | required |

Returns:

| Type | Description |
|---|---|
| float | Mean squared error. |

Source code in flexcv/metrics.py
from sklearn.metrics import mean_squared_error


def mse_wrapper(y_valid, y_pred, y_train_in, y_pred_train):
    """This function is used only to calculate the objective function value for the hyperparameter optimization.
    To allow for customized objective functions, it takes the validation and training targets and the corresponding predictions as arguments.
    This can be useful to avoid overfitting. The sklearn MSE function therefore had to be wrapped accordingly;
    this default implementation uses only the validation arguments.

    Args:
      y_valid (array-like): Target in the validation set.
      y_pred (array-like): Predictions for the validation set.
      y_train_in (array-like): Target in the training set.
      y_pred_train (array-like): Predictions for the training set.

    Returns:
        (float): Mean squared error.

    """
    return mean_squared_error(y_valid, y_pred)