def plot_activations(model, x_test, path, data_format=None):
    """Plot activations of a network.

    Computes the activations of ``model`` on ``x_test``. It then plots, layer
    by layer, the activity of the first sample of the batch (index 0).

    Parameters
    ----------

    model: keras.models.Model
        Keras model.

    x_test: ndarray
        The samples.

    path: str
        Where to save plot. The directory is created if it does not exist.

    data_format: Optional[str]
        One of 'channels_first' or 'channels_last'.
    """

    from snntoolbox.conversion.utils import get_activations_batch
    from snntoolbox.simulation.utils import get_sample_activity_from_batch

    activations_batch = get_activations_batch(model, x_test)
    # Only sample 0 of the batch is plotted.
    activations = get_sample_activity_from_batch(activations_batch, 0)

    # The output directory does not depend on the layer, so create it once.
    # ``exist_ok`` also avoids the race of a separate ``os.path.exists`` check.
    os.makedirs(path, exist_ok=True)

    for i, layer_activity in enumerate(activations):
        label = layer_activity[1]
        print("Plotting layer {}".format(label))
        # Prefix the label with the two-digit, zero-padded layer position.
        plot_layer_activity(layer_activity,
                            '{:02d}'.format(i) + label,
                            path,
                            data_format=data_format)
def output_graphs(plot_vars, config, path=None, idx=0, data_format=None):
    """Wrapper function to display / save a number of plots.

    Parameters
    ----------

    plot_vars: dict
        Example items:

        - spiketrains_n_b_l_t: list[tuple[np.array, str]]
            Each entry in ``spiketrains_batch`` contains a tuple
            ``(spiketimes, label)`` for each layer of the network (for the
            first batch only, and excluding ``Flatten`` layers).
            ``spiketimes`` is an array where the last index contains the spike
            times of the specific neuron, and the first indices run over the
            number of neurons in the layer:
            (batch_size, n_chnls, n_rows, n_cols, duration)
            ``label`` is a string specifying both the layer type and the index,
            e.g. ``'03Dense'``.

        - activations_n_b_l: list[tuple[np.array, str]]
            Activations of the ANN.

        - spikecounts_n_b_l: list[tuple[np.array, str]]
            Spikecounts of the SNN. Used to compute spikerates.

    config: configparser.ConfigParser
        Settings.

    path: Optional[str]
        If not ``None``, specifies where to save the resulting image. Else,
        display plots without saving.

    idx: int
        The index of the sample to display. Defaults to 0.

    data_format: Optional[str]
        One of 'channels_first' or 'channels_last'.
    """

    from snntoolbox.simulation.utils import spiketrains_to_rates
    from snntoolbox.simulation.utils import get_sample_activity_from_batch

    if plot_vars == {}:
        return

    if path is not None:
        print("Saving plots of one sample to {}...\n".format(path))

    plot_keys = eval(config.get('output', 'plot_vars'))
    duration = config.getint('simulation', 'duration')

    if 'activations_n_b_l' in plot_vars:
        plot_vars['activations_n_l'] = get_sample_activity_from_batch(
            plot_vars['activations_n_b_l'], idx)
    if 'spiketrains_n_b_l_t' in plot_vars:
        plot_vars['spiketrains_n_l_t'] = get_sample_activity_from_batch(
            plot_vars['spiketrains_n_b_l_t'], idx)
        if any({'spikerates', 'correlation', 'hist_spikerates_activations'}
               & plot_keys):
            if 'spikerates_n_b_l' not in plot_vars:
                plot_vars['spikerates_n_b_l'] = spiketrains_to_rates(
                    plot_vars['spiketrains_n_b_l_t'], duration,
                    config.get('conversion', 'spike_code'))
            plot_vars['spikerates_n_l'] = get_sample_activity_from_batch(
                plot_vars['spikerates_n_b_l'], idx)

    plot_layer_summaries(plot_vars, config, path, data_format)

    print("Plotting batch run statistics...")
    if 'spikecounts' in plot_keys:
        plot_spikecount_vs_time(plot_vars['spiketrains_n_b_l_t'], duration,
                                config.getfloat('simulation', 'dt'), path)
    if 'correlation' in plot_keys:
        plot_pearson_coefficients(plot_vars['spikerates_n_b_l'],
                                  plot_vars['activations_n_b_l'], config, path)
    if 'hist_spikerates_activations' in plot_keys:
        s = []
        a = []
        for i in range(len(plot_vars['spikerates_n_b_l'])):
            s += list(plot_vars['spikerates_n_b_l'][i][0].flatten())
            a += list(plot_vars['activations_n_b_l'][i][0].flatten())
        plot_hist({'Spikerates': s, 'Activations': a}, path=path)
    print("Done.\n")