Example #1
# Module-level imports assumed by this excerpt (the original page shows only the
# function body); the AE model class and the plotting helpers used below are
# defined elsewhere in the same package.
import os
import pickle

import numpy as np
import torch
def real_vs_sampled_wrapper(output_type,
                            hparams,
                            save_file,
                            sess_idx=0,
                            dtype='test',
                            conditional=True,
                            max_frames=400,
                            frame_rate=20,
                            n_buffer=5,
                            xtick_locs=None,
                            frame_rate_beh=None,
                            format='png'):
    """Produce movie with (AE) reconstructed video and sampled video.

    This is a high-level function that loads the model described in the hparams dictionary and
    produces the necessary state sequences/samples. The sampled video can be completely
    unconditional (states and latents are sampled) or conditioned on the most likely state
    sequence.

    Parameters
    ----------
    output_type : :obj:`str`
        'plot' | 'movie' | 'both'
    hparams : :obj:`dict`
        needs to contain enough information to specify an autoencoder
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make plot/video with; 'train' | 'val' | 'test'
    conditional : :obj:`bool`, optional
        conditional vs unconditional samples; also used to set the reconstruction movie title
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    n_buffer : :obj:`int`, optional
        number of blank frames between animated trials if more than one trial is needed to
        reach :obj:`max_frames`
    xtick_locs : :obj:`array-like`, optional
        tick locations in bin values for plot
    frame_rate_beh : :obj:`float`, optional
        behavioral video framerate; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle if :obj:`output_type='plot'` or :obj:`output_type='both'`, else
        nothing returned (movie is saved)

    """
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # check input - cannot create sampled movies for arhmm-labels models (no mapping from labels to
    # frames)
    if hparams['model_class'].find('labels') > -1:
        if output_type == 'both' or output_type == 'movie':
            print(
                'warning: cannot create video with "arhmm-labels" model; producing plots'
            )
            output_type = 'plot'

    # load latents and states (observed and sampled)
    model_output = get_model_latents_states(hparams,
                                            '',
                                            sess_idx=sess_idx,
                                            return_samples=50,
                                            cond_sampling=conditional)

    if output_type == 'both' or output_type == 'movie':

        # load in AE decoder
        if hparams.get('ae_model_path', None) is not None:
            ae_model_file = os.path.join(hparams['ae_model_path'],
                                         'best_val_model.pt')
            ae_arch = pickle.load(
                open(os.path.join(hparams['ae_model_path'], 'meta_tags.pkl'),
                     'rb'))
        else:
            hparams['session_dir'], sess_ids = get_session_dir(hparams)
            hparams['expt_dir'] = get_expt_dir(hparams)
            _, latents_file = get_transforms_paths('ae_latents', hparams,
                                                   sess_ids[sess_idx])
            ae_model_file = os.path.join(os.path.dirname(latents_file),
                                         'best_val_model.pt')
            ae_arch = pickle.load(
                open(
                    os.path.join(os.path.dirname(latents_file),
                                 'meta_tags.pkl'), 'rb'))
        print('loading model from %s' % ae_model_file)
        ae_model = AE(ae_arch)
        ae_model.load_state_dict(
            torch.load(ae_model_file,
                       map_location=lambda storage, loc: storage))
        ae_model.eval()

        n_channels = ae_model.hparams['n_input_channels']
        y_pix = ae_model.hparams['y_pixels']
        x_pix = ae_model.hparams['x_pixels']

        # push observed latents through ae decoder
        ims_recon = np.zeros((0, n_channels * y_pix, x_pix))
        i_trial = 0
        while ims_recon.shape[0] < max_frames:
            recon = ae_model.decoding(
                torch.tensor(model_output['latents'][dtype][i_trial]).float(),
                None, None).cpu().detach().numpy()
            recon = np.concatenate(
                [recon[:, i] for i in range(recon.shape[1])], axis=1)
            zero_frames = np.zeros((n_buffer, n_channels * y_pix,
                                    x_pix))  # add a few black frames
            ims_recon = np.concatenate((ims_recon, recon, zero_frames), axis=0)
            i_trial += 1

        # push sampled latents through ae decoder
        ims_recon_samp = np.zeros((0, n_channels * y_pix, x_pix))
        i_trial = 0
        while ims_recon_samp.shape[0] < max_frames:
            recon = ae_model.decoding(
                torch.tensor(model_output['latents_gen'][i_trial]).float(),
                None, None).cpu().detach().numpy()
            recon = np.concatenate(
                [recon[:, i] for i in range(recon.shape[1])], axis=1)
            zero_frames = np.zeros((n_buffer, n_channels * y_pix,
                                    x_pix))  # add a few black frames
            ims_recon_samp = np.concatenate(
                (ims_recon_samp, recon, zero_frames), axis=0)
            i_trial += 1

        make_real_vs_sampled_movies(ims_recon,
                                    ims_recon_samp,
                                    conditional=conditional,
                                    save_file=save_file,
                                    frame_rate=frame_rate)

    if output_type == 'both' or output_type == 'plot':

        i_trial = 0
        latents = model_output['latents'][dtype][i_trial][:max_frames]
        states = model_output['states'][dtype][i_trial][:max_frames]
        latents_samp = model_output['latents_gen'][i_trial][:max_frames]
        if not conditional:
            states_samp = model_output['states_gen'][i_trial][:max_frames]
        else:
            states_samp = []

        fig = plot_real_vs_sampled(latents,
                                   latents_samp,
                                   states,
                                   states_samp,
                                   save_file=save_file,
                                   xtick_locs=xtick_locs,
                                   frame_rate=frame_rate_beh,
                                   format=format)

    if output_type == 'movie':
        return None
    elif output_type == 'both' or output_type == 'plot':
        return fig
    else:
        raise ValueError('"%s" is an invalid output_type' % output_type)
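
A hedged usage sketch for the wrapper above. Building the hparams dict is elided
because its keys depend on the fitted models; make_arhmm_hparams() is a
hypothetical stand-in for however that dictionary gets assembled.

hparams = make_arhmm_hparams()  # hypothetical helper; must describe a fitted arhmm + AE
fig = real_vs_sampled_wrapper(
    'both', hparams, save_file='/tmp/real_vs_sampled', sess_idx=0,
    conditional=True, max_frames=200, frame_rate=20)
# 'both' saves the movie to save_file and returns the matplotlib figure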
Example #2
# Module-level imports assumed by this excerpt; get_discrete_chunks and
# make_syllable_movies are helpers defined elsewhere in the same module.
import os
import pickle

import numpy as np
def make_syllable_movies_wrapper(hparams,
                                 save_file,
                                 sess_idx=0,
                                 dtype='test',
                                 max_frames=400,
                                 frame_rate=10,
                                 min_threshold=0,
                                 n_buffer=5,
                                 n_pre_frames=3,
                                 n_rows=None,
                                 single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    This is a high-level function that loads the arhmm model described in the hparams dictionary
    and produces the necessary states/video frames.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make video with; 'train' | 'val' | 'test'
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    min_threshold : :obj:`int`, optional
        minimum number of frames in a syllable run to be considered for movie
    n_buffer : :obj:`int`, optional
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`, optional
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`, optional
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`, optional
        choose only a single state for movie

    """
    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # load images, latents, and states
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)
    hparams['load_videos'] = True
    hparams, signals, transforms, paths = get_data_generator_inputs(
        hparams, sess_ids)
    data_generator = ConcatSessionsGenerator(
        hparams['data_dir'],
        sess_ids,
        signals_list=[signals[sess_idx]],
        transforms_list=[transforms[sess_idx]],
        paths_list=[paths[sess_idx]],
        device='cpu',
        as_numpy=True,
        batch_load=False,
        rng_seed=hparams['rng_seed_data'])
    ims_orig = data_generator.datasets[sess_idx].data['images']
    del data_generator  # free up memory

    # get test-tube version number
    _, version = experiment_exists(hparams, which_version=True)
    print('producing syllable videos for arhmm %s' % version)
    # load latents/labels
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            latents = pickle.load(f)
    trial_idxs = latents['trials'][dtype]
    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)
    # infer discrete states
    states = [
        hmm.most_likely_states(latents['latents'][s])
        for s in latents['trials'][dtype]
    ]
    if len(states) == 0:
        raise ValueError('No latents for dtype=%s' % dtype)

    # find runs of discrete states; state_indices is a list, each entry of which is a np array
    # with shape (n_state_instances, 3), where the 3 values are:
    # chunk_idx, chunk_start_idx, chunk_end_idx
    # chunk_idx is in [0, n_chunks] and indexes trial_idxs
    state_indices = get_discrete_chunks(states, include_edges=True)
    K = len(state_indices)

    # get all examples over the minimum state length threshold
    over_threshold_instances = [[] for _ in range(K)]
    for i_state in range(K):
        if state_indices[i_state].shape[0] > 0:
            state_lens = np.diff(state_indices[i_state][:, 1:3], axis=1)
            over_idxs = state_lens > min_threshold
            over_threshold_instances[i_state] = state_indices[i_state][
                over_idxs[:, 0]]
            np.random.shuffle(
                over_threshold_instances[i_state])  # shuffle instances

    make_syllable_movies(ims_orig=ims_orig,
                         state_list=over_threshold_instances,
                         trial_idxs=trial_idxs,
                         save_file=save_file,
                         max_frames=max_frames,
                         frame_rate=frame_rate,
                         n_buffer=n_buffer,
                         n_pre_frames=n_pre_frames,
                         n_rows=n_rows,
                         single_syllable=single_syllable)
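
A hedged usage sketch; as above, make_arhmm_hparams() is a hypothetical helper
standing in for however the hparams dict gets built.

hparams = make_arhmm_hparams()  # hypothetical helper
# one panel per discrete state, skipping runs shorter than 5 frames
make_syllable_movies_wrapper(
    hparams, save_file='/tmp/syllables.mp4', sess_idx=0, dtype='test',
    max_frames=200, min_threshold=5)
# or zoom in on a single state:
make_syllable_movies_wrapper(
    hparams, save_file='/tmp/syllable_03.mp4', single_syllable=3)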
Example #3
# Module-level imports assumed by this excerpt:
import os
import pickle

import numpy as np
def get_model_latents_states(hparams,
                             version,
                             sess_idx=0,
                             return_samples=0,
                             cond_sampling=False,
                             dtype='test',
                             dtypes=['train', 'val', 'test'],
                             rng_seed=0):
    """Return arhmm defined in :obj:`hparams` with associated latents and states.

    Can also return sampled latents and states.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    version : :obj:`str` or :obj:`int`
        test tube model version (can be 'best')
    sess_idx : :obj:`int`, optional
        session index into data generator
    return_samples : :obj:`int`, optional
        number of trials to sample from model
    cond_sampling : :obj:`bool`, optional
        if :obj:`True` return samples conditioned on most likely state sequence; else return
        unconditioned samples
    dtype : :obj:`str`, optional
        trial type to use for conditional sampling; 'train' | 'val' | 'test'
    dtypes : :obj:`array-like`, optional
        trial types for which to collect latents and states
    rng_seed : :obj:`int`, optional
        random number generator seed to control sampling

    Returns
    -------
    :obj:`dict`
        - 'model' (:obj:`ssm.HMM` object)
        - 'latents' (:obj:`dict`): latents from train, val and test trials
        - 'states' (:obj:`dict`): states from train, val and test trials
        - 'trial_idxs' (:obj:`dict`): trial indices from train, val and test trials
        - 'latents_gen' (:obj:`list`)
        - 'states_gen' (:obj:`list`)

    """
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # get version/model
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'],
                                         measure='val_loss',
                                         best_def='max')[0]
    else:
        _, version = experiment_exists(hparams, which_version=True)
    if version is None:
        raise FileNotFoundError(
            'Could not find the specified model version in %s' %
            hparams['expt_dir'])

    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # load latents/labels
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        all_latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            all_latents = pickle.load(f)

    # collect inferred latents/states
    trial_idxs = {}
    latents = {}
    states = {}
    for data_type in dtypes:
        trial_idxs[data_type] = all_latents['trials'][data_type]
        latents[data_type] = [
            all_latents['latents'][i_trial]
            for i_trial in trial_idxs[data_type]
        ]
        states[data_type] = [
            hmm.most_likely_states(x) for x in latents[data_type]
        ]

    # collect sampled latents/states
    if return_samples > 0:
        states_gen = []
        np.random.seed(rng_seed)
        if cond_sampling:
            n_latents = latents[dtype][0].shape[1]
            latents_gen = [
                np.zeros((len(state_seg), n_latents))
                for state_seg in states[dtype]
            ]
            for i_seg, state_seg in enumerate(states[dtype]):
                for i_t in range(len(state_seg)):
                    if i_t >= 1:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents_gen[i_seg][:i_t],
                            input=np.zeros(0))
                    else:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents[dtype][i_seg][0].reshape((1, n_latents)),
                            input=np.zeros(0))
        else:
            latents_gen = []
            offset = 200  # burn-in samples to discard so trials don't all start identically
            for i in range(return_samples):
                these_states_gen, these_latents_gen = hmm.sample(
                    latents[dtype][0].shape[0] + offset)
                latents_gen.append(these_latents_gen[offset:])
                states_gen.append(these_states_gen[offset:])
    else:
        latents_gen = []
        states_gen = []

    return_dict = {
        'model': hmm,
        'latents': latents,
        'states': states,
        'trial_idxs': trial_idxs,
        'latents_gen': latents_gen,
        'states_gen': states_gen,
    }
    return return_dict
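
A short sketch of inspecting what comes back, assuming the hparams dict describes
a fitted arhmm (the printed quantities are illustrative, not guaranteed):

out = get_model_latents_states(hparams, version='best', return_samples=10)
print(type(out['model']))             # the unpickled ssm HMM
print(len(out['latents']['test']))    # one (n_frames, n_latents) array per test trial
print(out['states']['test'][0][:20])  # discrete state sequence, first 20 frames
print(len(out['latents_gen']))        # 10 unconditioned sampled trials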
Example #4
# Module-level imports assumed by this excerpt; _get_dataset_str, get_region_list,
# get_session_dir, get_expt_dir and get_subdirs are helpers imported elsewhere in
# the source module.
import os
import pickle

import pandas as pd
def get_r2s_by_trial(hparams, model_types):
    """For a given session, load R^2 metrics from all decoders defined by hparams.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify decoders
    model_types : :obj:`list` of :obj:`str`
        'mlp' | 'mlp-mv' | 'lstm'

    Returns
    -------
    :obj:`pd.DataFrame`
        pandas dataframe of decoder validation metrics

    """

    dataset = _get_dataset_str(hparams)
    region_names = get_region_list(hparams)

    metrics = []
    model_idx = 0
    model_counter = 0
    for region in region_names:
        hparams['region'] = region
        for model_type in model_types:

            hparams['session_dir'], _ = get_session_dir(
                hparams, session_source=hparams.get('all_source', 'save'))
            expt_dir = get_expt_dir(hparams,
                                    model_type=model_type,
                                    model_class=hparams['model_class'],
                                    expt_name=hparams['experiment_name'])

            # gather all versions
            try:
                versions = get_subdirs(expt_dir)
            except Exception:
                print('No models in %s; skipping' % expt_dir)
                continue

            # load csv files with model metrics (saved out from test tube)
            for i, version in enumerate(versions):
                # read metrics csv file
                model_dir = os.path.join(expt_dir, version)
                try:
                    metric = pd.read_csv(os.path.join(model_dir,
                                                      'metrics.csv'))
                    model_counter += 1
                except FileNotFoundError:
                    continue
                with open(os.path.join(model_dir, 'meta_tags.pkl'), 'rb') as f:
                    hparams = pickle.load(f)
                # append model info to metrics
                version_num = int(version[8:])  # strip leading 'version_'
                metric['version'] = 'version_%i' % (model_idx + version_num)
                metric['region'] = region
                metric['dataset'] = dataset
                metric['model_type'] = model_type
                for key, val in hparams.items():
                    if isinstance(val, (str, int, float)):
                        metric[key] = val
                metrics.append(metric)

            model_idx += 10000  # assumes no more than 10k model versions/expt
    # put everything in pandas dataframe
    metrics_df = pd.concat(metrics, sort=False)
    return metrics_df
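
A hedged sketch of slicing the returned dataframe. The metric column names come
from whatever test tube logged during training, so 'val_loss' below is a
placeholder; inspect metrics_df.columns for the real ones.

metrics_df = get_r2s_by_trial(hparams, model_types=['mlp', 'lstm'])
print(metrics_df.columns)  # discover the logged metric names
# e.g. best validation loss per region/model type ('val_loss' is assumed here)
best = metrics_df.groupby(['region', 'model_type'])['val_loss'].min()
print(best)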
Example #5
# Module-level imports assumed by this excerpt; the behavenet helpers
# (get_lab_example, get_session_dir, get_expt_dir, get_best_model_version,
# experiment_exists, read_session_info_from_csv) are imported elsewhere in the
# source module.
import os

import pandas as pd
def load_metrics_csv_as_df(hparams,
                           lab,
                           expt,
                           metrics_list,
                           test=False,
                           version='best'):
    """Load metrics csv file and return as a pandas dataframe for easy plotting.

    Parameters
    ----------
    hparams : :obj:`dict`
        requires `sessions_csv`, `multisession`, `lab`, `expt`, `animal` and `session`
    lab : :obj:`str`
        for `get_lab_example`
    expt : :obj:`str`
        for `get_lab_example`
    metrics_list : :obj:`list`
        names of metrics to pull from csv; do not prepend with 'tr', 'val', or 'test'
    test : :obj:`bool`
        True to only return test values (computed once at end of training)
    version : :obj:`str`, :obj:`int`, or :obj:`NoneType`, optional
        'best' to find best model in tt expt, None to find model with hyperparams defined in
        `hparams`, int to load specific model

    Returns
    -------
    :obj:`pandas.DataFrame` object

    """

    # programmatically fill out other hparams options
    get_lab_example(hparams, lab, expt)
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # find metrics csv file
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'])[0]
    elif isinstance(version, int):
        pass  # use the version as given
    else:
        _, version = experiment_exists(hparams, which_version=True)
    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)
    metric_file = os.path.join(version_dir, 'metrics.csv')
    metrics = pd.read_csv(metric_file)

    # collect data from csv file
    sess_ids = read_session_info_from_csv(
        os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = []
    for sess_id in sess_ids:
        sess_ids_strs.append(
            str('%s/%s' % (sess_id['animal'], sess_id['session'])))
    metrics_df = []
    for i, row in metrics.iterrows():
        dataset = 'all' if row['dataset'] == -1 else sess_ids_strs[
            row['dataset']]
        if test:
            test_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'test'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **test_dict, 'loss': metric,
                            'val': row['test_%s' % metric]
                        },
                        index=[0]))
        else:
            # make dict for val data
            val_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'val'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **val_dict, 'loss': metric,
                            'val': row['val_%s' % metric]
                        },
                        index=[0]))
            # NOTE: the commented lines below are the old version, which returned a single
            # dataframe row containing all losses per epoch; the new way creates one row per
            # loss, making it easy to use with seaborn's FacetGrid object for multi-axis
            # plotting
            # for metric in metrics_list:
            #     val_dict[metric] = row['val_%s' % metric]
            # metrics_df.append(pd.DataFrame(val_dict, index=[0]))
            # make dict for train data
            tr_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'train'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **tr_dict, 'loss': metric,
                            'val': row['tr_%s' % metric]
                        },
                        index=[0]))
            # for metric in metrics_list:
            #     tr_dict[metric] = row['tr_%s' % metric]
            # metrics_df.append(pd.DataFrame(tr_dict, index=[0]))
    return pd.concat(metrics_df, sort=True)
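
Since each returned row carries a single ('dataset', 'epoch', 'dtype', 'loss',
'val') combination, the frame drops straight into seaborn's FacetGrid, as the
NOTE above suggests. Lab and expt names here are placeholders:

import seaborn as sns

df = load_metrics_csv_as_df(hparams, 'mylab', 'myexpt', metrics_list=['loss'], version='best')
g = sns.FacetGrid(df, col='loss', hue='dtype', sharey=False)
g.map(sns.lineplot, 'epoch', 'val')
g.add_legend()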
Example #6
# Test-class method excerpt; assumes pytest, os, the behavenet fitting utils
# module (imported as `utils`), and class fixtures (self.tmpdir, self.sess_ids,
# the *_csv/*_idxs/*_id attributes) set up elsewhere in the test file.
    def test_get_session_dir(self):

        hparams = {'data_dir': self.tmpdir, 'save_dir': self.tmpdir}

        # ------------------------------------------------------------
        # csv contained in multisession directory
        # ------------------------------------------------------------
        # single session from one animal
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'expt0'
        hparams['animal'] = 'animal1'
        hparams['session'] = 'session-00'
        hparams['sessions_csv'] = self.l0e0a1_csv
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            hparams['session'])
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        assert sess_dir == sess_dir_
        assert sess_single == [self.sess_ids[i] for i in self.l0e0a1_idxs]

        # multiple sessions from one animal
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'expt0'
        hparams['animal'] = 'animal0'
        hparams['sessions_csv'] = self.l0e0a0_csv
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            'multisession-%02i' % self.l0e0a0_id)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        assert sess_dir == sess_dir_
        assert sess_single == [self.sess_ids[i] for i in self.l0e0a0_idxs]

        # multiple sessions from one experiment
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'expt0'
        hparams['sessions_csv'] = self.l0e0_csv
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'],
            'multisession-%02i' % self.l0e0_id)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        assert sess_dir == sess_dir_
        assert sess_single == [self.sess_ids[i] for i in self.l0e0_idxs]

        # multiple sessions from one lab
        hparams['lab'] = 'lab0'
        hparams['sessions_csv'] = self.l0_csv
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], 'multisession-%02i' % self.l0_id)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        assert sess_dir == sess_dir_
        assert sess_single == [self.sess_ids[i] for i in self.l0_idxs]

        # multiple sessions from multiple labs
        hparams['sessions_csv'] = self.l_csv
        with pytest.raises(NotImplementedError):
            utils.get_session_dir(hparams, session_source='save')

        # ------------------------------------------------------------
        # use 'all' in hparams instead of csv file
        # ------------------------------------------------------------
        hparams['sessions_csv'] = ''

        # all labs
        hparams['lab'] = 'all'
        with pytest.raises(NotImplementedError):
            utils.get_session_dir(hparams, session_source='save')

        # all experiments
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'all'
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], 'multisession-%02i' % self.l0_id)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        sess_single = [dict2str(d) for d in sess_single]
        sess_single_ = [dict2str(self.sess_ids[i]) for i in self.l0_idxs]
        assert sess_dir == sess_dir_
        assert sorted(sess_single) == sorted(sess_single_)

        # all animals
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'expt0'
        hparams['animal'] = 'all'
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'],
            'multisession-%02i' % self.l0e0_id)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        sess_single = [dict2str(d) for d in sess_single]
        sess_single_ = [dict2str(self.sess_ids[i]) for i in self.l0e0_idxs]
        assert sess_dir == sess_dir_
        assert sorted(sess_single) == sorted(sess_single_)

        # all sessions
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'expt0'
        hparams['animal'] = 'animal0'
        hparams['session'] = 'all'
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            'multisession-%02i' % self.l0e0a0_id)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        sess_single = [dict2str(d) for d in sess_single]
        sess_single_ = [dict2str(self.sess_ids[i]) for i in self.l0e0a0_idxs]
        assert sess_dir == sess_dir_
        assert sorted(sess_single) == sorted(sess_single_)

        # single session
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'expt0'
        hparams['animal'] = 'animal0'
        hparams['session'] = 'session-%02i' % self.l0e0a0s0_id
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            hparams['session'])
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        sess_single = [dict2str(d) for d in sess_single]
        sess_single_ = [dict2str(self.sess_ids[i]) for i in self.l0e0a0s0_idxs]
        assert sess_dir == sess_dir_
        assert sorted(sess_single) == sorted(sess_single_)

        # ------------------------------------------------------------
        # use 'all' to define level, then define existing multisession
        # ------------------------------------------------------------
        hparams['lab'] = 'lab0'
        hparams['expt'] = 'expt0'
        hparams['animal'] = 'animal0'
        hparams['session'] = 'all'
        hparams['multisession'] = 1
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            'multisession-%02i' % self.l0e0a0m1_id)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        sess_single = [dict2str(d) for d in sess_single]
        sess_single_ = [dict2str(self.sess_ids[i]) for i in self.l0e0a0m1_idxs]
        assert sess_dir == sess_dir_
        assert sorted(sess_single) == sorted(sess_single_)

        # TODO: return correct single session if multisession returns single

        # ------------------------------------------------------------
        #  use 'all' to define level, no existing multisession
        # ------------------------------------------------------------
        hparams['lab'] = 'lab1'
        hparams['expt'] = 'expt0'
        hparams['animal'] = 'animal0'
        hparams['session'] = 'all'
        hparams['multisession'] = None
        sess_dir_ = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            'multisession-%02i' % 0)
        sess_dir, sess_single = utils.get_session_dir(hparams, session_source='save')
        sess_single = [dict2str(d) for d in sess_single]
        sess_single_ = [dict2str(self.sess_ids[i]) for i in self.l1e0a0_idxs]
        assert sess_dir == sess_dir_
        assert sorted(sess_single) == sorted(sess_single_)

        # ------------------------------------------------------------
        # other
        # ------------------------------------------------------------
        # bad 'session_source'
        with pytest.raises(ValueError):
            utils.get_session_dir(hparams, session_source='test')
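
The assertions above compare sorted string forms of session-id dicts via a
dict2str helper from the test fixtures. A minimal sketch of such a helper, as an
assumption about what the real fixture does:

def dict2str(d):
    # canonical, order-independent string form of a session-id dict, so lists of
    # dicts can be sorted and compared regardless of insertion order
    return ';'.join('%s=%s' % (k, d[k]) for k in sorted(d))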