# Source code for cardio.models.dirichlet_model.dirichlet_model

"""Contains Dirichlet model class."""

from itertools import zip_longest

import numpy as np
import tensorflow as tf

from ..layers import conv1d_block, resnet1d_block
from ...dataset.dataset.models.tf import TFModel


def concatenate_ecg_batch(batch, model, return_targets=True):
    """Concatenate batch signals and (optionally) targets.

    Parameters
    ----------
    batch : EcgBatch
        Batch to concatenate.
    model : BaseModel
        A model to get the resulting arguments.
    return_targets : bool
        Specifies whether to return concatenated targets.

    Returns
    -------
    kwargs : dict
        Named arguments for model's train or predict method. Has the following
        structure:
        "feed_dict" : dict
            "signals" : 3-D ndarray
                Concatenated signals.
            "targets" : 2-D ndarray, optional
                Concatenated targets.
        "split_indices" : 1-D ndarray
            Split indices to undo the concatenation.
    """
    _ = model
    # Number of signal segments contributed by each batch item; the cumulative
    # sums (all but the last) are exactly the indices needed to split the
    # concatenated arrays back into per-item pieces.
    lengths = [item.signal.shape[0] for item in batch]
    split_indices = np.cumsum(lengths)[:-1]
    signals = np.concatenate(batch.signal)
    result = {"feed_dict": {"signals": signals}, "split_indices": split_indices}
    if return_targets:
        # Repeat each item's target once per segment so targets align
        # row-for-row with the concatenated signals.
        tiled = [np.tile(item.target, (length, 1)) for item, length in zip(batch, lengths)]
        result["feed_dict"]["targets"] = np.concatenate(tiled)
    return result


class DirichletModelBase(TFModel):
    """Dirichlet model class.

    The model predicts Dirichlet distribution parameters from which class
    probabilities are sampled.

    Notes
    -----
    **Configuration**

    Model config must contain the following keys:

    * input_shape : tuple
        Input signals's shape without the batch dimension.
    * class_names : array_like
        Class names.
    * loss : ``None``
        The model has a predefined loss, so you should leave it ``None``.
    """

    def _build(self, config=None):  # pylint: disable=too-many-locals
        """Build Dirichlet model."""
        input_shape = self.config["input_shape"]
        class_names = self.config["class_names"]
        with self:  # pylint: disable=not-context-manager
            self.store_to_attr("class_names", tf.constant(class_names))
            signals = tf.placeholder(tf.float32, shape=(None,) + input_shape, name="signals")
            self.store_to_attr("signals", signals)
            # Convolutional layers below expect channels-last layout.
            signals_channels_last = tf.transpose(signals, perm=[0, 2, 1],
                                                 name="signals_channels_last")
            k = 0.001
            targets = tf.placeholder(tf.float32, shape=(None, len(class_names)), name="targets")
            self.store_to_attr("targets", targets)
            # Label smoothing: keeps every target component in (0, 1) so that
            # log(targets_soft) in the loss stays finite.
            targets_soft = (1 - 2 * k) * targets + k

            block = conv1d_block("conv", signals_channels_last, is_training=self.is_training,
                                 filters=8, kernel_size=5)
            # (filters, kernel_size, downsample) for each residual block.
            block_config = [
                (8, 3, True),
                (8, 3, False),
                (8, 3, True),
                (8, 3, False),
                (12, 3, True),
                (12, 3, False),
                (12, 3, True),
                (12, 3, False),
                (16, 3, True),
                (16, 3, False),
                (16, 3, False),
                (16, 3, True),
                (16, 3, False),
                (16, 3, False),
                (20, 3, True),
                (20, 3, False),
            ]
            for i, (filters, kernel_size, downsample) in enumerate(block_config):
                block = resnet1d_block("block_" + str(i + 1), block,
                                       is_training=self.is_training, filters=filters,
                                       kernel_size=kernel_size, downsample=downsample)

            with tf.variable_scope("global_max_pooling"):  # pylint: disable=not-context-manager
                block = tf.reduce_max(block, axis=1)

            with tf.variable_scope("dense"):  # pylint: disable=not-context-manager
                dense = tf.layers.dense(block, len(class_names), use_bias=False, name="dense")
                bnorm = tf.layers.batch_normalization(dense, training=self.is_training,
                                                      name="batch_norm", fused=True)
                # softplus guarantees strictly positive Dirichlet parameters.
                act = tf.nn.softplus(bnorm, name="activation")

            parameters = tf.identity(act, name="parameters")
            self.store_to_attr("parameters", parameters)
            predictions = tf.identity(act, name="predictions")
            self.store_to_attr("predictions", predictions)

            # Negative log-likelihood of the smoothed targets under the
            # predicted Dirichlet distribution.
            loss = tf.reduce_mean(tf.lbeta(parameters) -
                                  tf.reduce_sum((parameters - 1) * tf.log(targets_soft), axis=1),
                                  name="loss")
            tf.losses.add_loss(loss)
class DirichletModel(DirichletModelBase):
    """Dirichlet model with overloaded train and predict methods.

    * ``train`` method is identical to ``DirichletModelBase.train``, but also
      accepts ``args`` and ``kwargs``.
    * ``predict`` method splits the resulting tensor for ``parameters`` fetch
      using ``split_indices``. It also splits and aggregates results for
      ``predictions`` fetch to get class probabilities.
    """

    @staticmethod
    def _get_dirichlet_mixture_stats(alpha):
        """Get mean and variance vectors of the mixture of Dirichlet
        distributions with equal weights and given parameters.

        Parameters
        ----------
        alpha : 2-D ndarray
            Dirichlet distribution parameters along axis 1 for each mixture
            component.

        Returns
        -------
        mean : 1-D ndarray
            Mean of the mixture.
        var : 1-D ndarray
            Variance of the mixture.
        """
        alpha_sum = np.sum(alpha, axis=1)[:, np.newaxis]
        # First and second raw moments of each Dirichlet component.
        comp_m1 = alpha / alpha_sum
        comp_m2 = (alpha * (alpha + 1)) / (alpha_sum * (alpha_sum + 1))
        mean = np.mean(comp_m1, axis=0)
        var = np.mean(comp_m2, axis=0) - mean**2
        return mean, var

    def train(self, fetches=None, feed_dict=None, use_lock=False, *args, **kwargs):
        """Train the model with the data provided.

        The only difference between ``DirichletModel.train`` and
        ``TFModel.train`` is that the former also accepts ``args`` and
        ``kwargs``.

        Parameters
        ----------
        fetches : tf.Operation or tf.Tensor or array-like sequence of them
            Graph element to evaluate in addition to ``train_step``.
        feed_dict : dict
            A dictionary that maps graph elements to values.
        use_lock : bool
            If ``True``, the whole train step is locked, thus allowing for
            multithreading.

        Returns
        -------
        output : same structure as ``fetches``
            Calculated values for each element in ``fetches``.
        """
        _ = args, kwargs
        return super().train(fetches, feed_dict, use_lock)

    def predict(self, fetches=None, feed_dict=None, split_indices=None):  # pylint: disable=arguments-differ
        """Get predictions on the data provided.

        Parameters
        ----------
        fetches : tf.Operation or tf.Tensor or array-like sequence of them
            Graph element to evaluate. If ``fetches`` contains ``parameters``
            tensor, the corresponding output is split using ``split_indices``.
            If ``fetches`` contains ``predictions`` tensor, the corresponding
            output is split using ``split_indices`` and then aggregated to get
            class probabilities.
        feed_dict : dict
            A dictionary that maps graph elements to values.
        split_indices : 1-D ndarray
            Indices used to split ``parameters`` and ``predictions`` tensors.

        Returns
        -------
        output : same structure as ``fetches``
            Calculated values for each element in ``fetches``.
        """
        if isinstance(fetches, (list, tuple)):
            fetches_list = list(fetches)
        else:
            fetches_list = [fetches]
        output = super().predict(fetches_list, feed_dict)
        for i, fetch in enumerate(fetches_list):
            if fetch == "parameters":
                output[i] = np.split(output[i], split_indices)
            elif fetch == "predictions":
                class_names = self.class_names.eval(session=self.session)  # pylint: disable=no-member
                class_names = [c.decode("utf-8") for c in class_names]
                n_classes = len(class_names)
                # Maximal possible variance of a categorical-like mean vector;
                # used to normalize uncertainty into [0, 1].
                max_var = (n_classes - 1) / n_classes**2
                alpha = np.split(output[i], split_indices)
                targets = feed_dict.get("targets")
                # Targets were tiled per segment at concatenation time, so the
                # first row of each split group is the item's target.
                targets = [] if targets is None else [t[0] for t in np.split(targets, split_indices)]
                res = []
                for a, t in zip_longest(alpha, targets):
                    mean, var = self._get_dirichlet_mixture_stats(a)
                    uncertainty = var[np.argmax(mean)] / max_var
                    predictions_dict = {"target_pred": dict(zip(class_names, mean)),
                                        "uncertainty": uncertainty}
                    if t is not None:
                        predictions_dict["target_true"] = dict(zip(class_names, t))
                    res.append(predictions_dict)
                output[i] = res
        if isinstance(fetches, list):
            pass
        elif isinstance(fetches, tuple):
            output = tuple(output)
        else:
            output = output[0]
        return output