def test_export_session_info_to_csv(self):

        # create csv file
        sess_idxs = [0, 1, 2, 4, 5]
        csv_file = os.path.join(self.tmpdir, 'session_info.csv')
        utils.export_session_info_to_csv(self.tmpdir, [self.sess_ids[i] for i in sess_idxs])
        # load csv file
        sessions = utils.read_session_info_from_csv(csv_file)
        assert sessions == [self.sess_ids[i] for i in sess_idxs]
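        # A minimal sketch of the assumed round trip (hypothetical values): each entry of
        # self.sess_ids is assumed to be a dict of the form
        #     {'lab': 'lab0', 'expt': 'expt0', 'animal': 'animal0', 'session': 'session-00'}
        # export_session_info_to_csv writes one such row per session to session_info.csv, and
        # read_session_info_from_csv reads them back, so the assert above checks that the
        # write/read cycle is lossless for the selected subset of sessions.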
Example 2
def export_train_plots(hparams,
                       dtype,
                       loss_type='mse',
                       save_file=None,
                       format='png'):
    """Export plot with MSE/LL as a function of training epochs.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify the desired model (autoencoder, arhmm, etc.)
    dtype : :obj:`str`
        type of trials to use for plotting: 'train' | 'val' (metrics are not computed for 'test'
        trials throughout training)
    loss_type : :obj:`str`, optional
        'mse' | 'll'
    save_file : :obj:`str` or :obj:`NoneType`, optional
        full filename (absolute path) for saving plot; if :obj:`NoneType`, plot is displayed
    format : :obj:`str`, optional
        file format of the saved plot, e.g. 'png' | 'pdf' | 'jpeg'

    """
    import os
    import pandas as pd
    import seaborn as sns
    import matplotlib as mpl
    mpl.use('Agg')  # select a non-interactive backend; deals with display-less machines
    import matplotlib.pyplot as plt
    from behavenet.fitting.utils import read_session_info_from_csv

    sns.set_style('white')
    sns.set_context('talk')

    # find metrics csv file
    version_dir = os.path.join(hparams['expt_dir'],
                               'version_%i' % hparams['version'])
    metric_file = os.path.join(version_dir, 'metrics.csv')
    metrics = pd.read_csv(metric_file)
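    # metrics.csv is assumed to contain at least the columns used below; an illustrative layout
    # (values are made up):
    #     dataset,epoch,tr_loss,val_loss
    #     -1,0,0.0150,0.0170
    #     -1,1,0.0120,0.0140
    # a 'dataset' value of -1 denotes metrics aggregated over all sessions (mapped to 'all' below)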

    # collect data from csv file
    sess_ids = read_session_info_from_csv(
        os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = []
    for sess_id in sess_ids:
        sess_ids_strs.append(
            str('%s/%s' % (sess_id['animal'], sess_id['session'])))
    metrics_df = []
    for i, row in metrics.iterrows():
        dataset = 'all' if row['dataset'] == -1 else sess_ids_strs[
            row['dataset']]
        metrics_df.append(
            pd.DataFrame(
                {
                    'dataset': dataset,
                    'epoch': row['epoch'],
                    'loss': row['val_loss'],
                    'dtype': 'val',
                },
                index=[0]))
        metrics_df.append(
            pd.DataFrame(
                {
                    'dataset': dataset,
                    'epoch': row['epoch'],
                    'loss': row['tr_loss'],
                    'dtype': 'train',
                },
                index=[0]))
    metrics_df = pd.concat(metrics_df)

    # plot data
    data_queried = metrics_df[(metrics_df.dtype == dtype)
                              & (metrics_df.epoch > 0)
                              & ~pd.isna(metrics_df.loss)]
    splt = sns.relplot(x='epoch',
                       y='loss',
                       hue='dataset',
                       kind='line',
                       data=data_queried)
    splt.ax.set_xlabel('Epoch')
    if loss_type == 'mse':
        splt.ax.set_yscale('log')
        splt.ax.set_ylabel('MSE per pixel')
    elif loss_type == 'll':
        splt.ax.set_ylabel('Neg log prob per datapoint')
    else:
        raise ValueError('"%s" is an invalid loss type' % loss_type)
    title_str = 'Validation' if dtype == 'val' else 'Training'
    plt.title('%s loss' % title_str)

    if save_file is not None:
        plt.savefig(str('%s.%s' % (save_file, format)), dpi=300, format=format)
        plt.close()
    else:
        plt.show()

    return splt
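
# Usage sketch for export_train_plots (an assumption, not from the source): the function only
# reads 'expt_dir' and 'version' from hparams directly, so a minimal dict like the one below is
# assumed to be sufficient; all paths are hypothetical.
if __name__ == '__main__':
    hparams = {
        'expt_dir': '/path/to/expt_dir',  # directory containing version_* subdirectories
        'version': 0,                     # model version whose metrics.csv will be plotted
    }
    # writes /path/to/train_curves.png with validation MSE as a function of epoch
    export_train_plots(hparams, dtype='val', loss_type='mse',
                       save_file='/path/to/train_curves', format='png')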
Example 3
def load_metrics_csv_as_df(hparams,
                           lab,
                           expt,
                           metrics_list,
                           test=False,
                           version='best'):
    """Load metrics csv file and return as a pandas dataframe for easy plotting.

    Parameters
    ----------
    hparams : :obj:`dict`
        requires `sessions_csv`, `multisession`, `lab`, `expt`, `animal` and `session`
    lab : :obj:`str`
        for `get_lab_example`
    expt : :obj:`str`
        for `get_lab_example`
    metrics_list : :obj:`list`
        names of metrics to pull from csv; do not prepend with 'tr', 'val', or 'test'
    test : :obj:`bool`, optional
        :obj:`True` to only return test values (computed once at the end of training)
    version : :obj:`str`, :obj:`int`, or :obj:`NoneType`, optional
        'best' to find the best model in the test-tube experiment, an :obj:`int` to load a
        specific model version, or :obj:`NoneType` to find the model whose hyperparameters are
        defined in `hparams`

    Returns
    -------
    :obj:`pandas.DataFrame` object

    """

    # programmatically fill out other hparams options
    get_lab_example(hparams, lab, expt)
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # find metrics csv file
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'])[0]
    elif isinstance(version, int):
        pass  # use the user-specified version as-is
    else:
        _, version = experiment_exists(hparams, which_version=True)
    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)
    metric_file = os.path.join(version_dir, 'metrics.csv')
    metrics = pd.read_csv(metric_file)

    # collect data from csv file
    sess_ids = read_session_info_from_csv(
        os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = []
    for sess_id in sess_ids:
        sess_ids_strs.append(
            str('%s/%s' % (sess_id['animal'], sess_id['session'])))
    metrics_df = []
    for i, row in metrics.iterrows():
        dataset = 'all' if row['dataset'] == -1 else sess_ids_strs[
            row['dataset']]
        if test:
            test_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'test'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **test_dict, 'loss': metric,
                            'val': row['test_%s' % metric]
                        },
                        index=[0]))
        else:
            # make dict for val data
            val_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'val'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **val_dict, 'loss': metric,
                            'val': row['val_%s' % metric]
                        },
                        index=[0]))
            # NOTE: the commented-out lines below are the old version, which returned a single
            # dataframe row containing all losses per epoch; the new version creates one row per
            # loss, making it easy to use with seaborn's FacetGrid object for multi-axis plotting
            # for metric in metrics_list:
            #     val_dict[metric] = row['val_%s' % metric]
            # metrics_df.append(pd.DataFrame(val_dict, index=[0]))
            # make dict for train data
            tr_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'train'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **tr_dict, 'loss': metric,
                            'val': row['tr_%s' % metric]
                        },
                        index=[0]))
            # for metric in metrics_list:
            #     tr_dict[metric] = row['tr_%s' % metric]
            # metrics_df.append(pd.DataFrame(tr_dict, index=[0]))
    return pd.concat(metrics_df, sort=True)
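
# Usage sketch for load_metrics_csv_as_df (hypothetical lab/expt names and hparams contents; in
# practice hparams must already contain the keys listed in the docstring, e.g. 'sessions_csv' and
# 'multisession'). The long-format dataframe, with one row per (epoch, dataset, metric), is
# assumed to plug directly into seaborn's FacetGrid, as the note inside the function suggests.
if __name__ == '__main__':
    import seaborn as sns
    import matplotlib.pyplot as plt

    hparams = {}  # hypothetical; fill with the entries required by get_session_dir/get_expt_dir
    metrics_df = load_metrics_csv_as_df(
        hparams, lab='lab0', expt='expt0', metrics_list=['loss'], version='best')

    # one column of axes per metric name (stored in the 'loss' column), train/val split by hue
    g = sns.FacetGrid(metrics_df, col='loss', hue='dtype', sharey=False)
    g.map(plt.plot, 'epoch', 'val')
    g.add_legend()
    plt.show()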
    def test_read_session_info_from_csv(self):

        sessions = utils.read_session_info_from_csv(self.l0e0a0_csv)
        assert sessions == [self.sess_ids[i] for i in self.l0e0a0_idxs]