示例#1
0
def real_vs_sampled_wrapper(output_type,
                            hparams,
                            save_file,
                            sess_idx,
                            dtype='test',
                            conditional=True,
                            max_frames=400,
                            frame_rate=20,
                            n_buffer=5,
                            xtick_locs=None,
                            frame_rate_beh=None,
                            format='png'):
    """Produce movie with (AE) reconstructed video and sampled video.

    This is a high-level function that loads the model described in the hparams dictionary and
    produces the necessary state sequences/samples. The sampled video can be completely
    unconditional (states and latents are sampled) or conditioned on the most likely state
    sequence.

    Parameters
    ----------
    output_type : :obj:`str`
        'plot' | 'movie' | 'both'
    hparams : :obj:`dict`
        needs to contain enough information to specify an autoencoder; note that 'session_dir'
        and 'expt_dir' entries are written into this dict
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make plot/video with; 'train' | 'val' | 'test'
    conditional : :obj:`bool`
        conditional vs unconditional samples; for creating reconstruction title
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    n_buffer : :obj:`int`
        number of blank frames between animated trials if more than one is needed to reach
        :obj:`max_frames`
    xtick_locs : :obj:`array-like`, optional
        tick locations in bin values for plot
    frame_rate_beh : :obj:`float`, optional
        behavioral video framerate; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle if :obj:`output_type='plot'` or :obj:`output_type='both'`, else
        nothing returned (movie is saved)

    Raises
    ------
    ValueError
        if :obj:`output_type` is not one of 'plot', 'movie' or 'both'

    """
    # fail fast: reject a bad request before any model/data is loaded
    if output_type not in ('plot', 'movie', 'both'):
        raise ValueError('"%s" is an invalid output_type' % output_type)

    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # check input - cannot create sampled movies for arhmm-labels models (no mapping from labels to
    # frames)
    if 'labels' in hparams['model_class'] and output_type in ('both', 'movie'):
        print(
            'warning: cannot create video with "arhmm-labels" model; producing plots'
        )
        output_type = 'plot'

    # load latents and states (observed and sampled); forward `dtype` so that conditional samples
    # are generated from the same trials that are displayed alongside them
    model_output = get_model_latents_states(hparams,
                                            '',
                                            sess_idx=sess_idx,
                                            return_samples=50,
                                            cond_sampling=conditional,
                                            dtype=dtype)

    if output_type in ('both', 'movie'):

        # locate the directory holding the AE decoder weights and architecture
        ae_dir = hparams.get('ae_model_path', None)
        if ae_dir is None:
            # fall back to the directory of the AE latents file used by this model
            hparams['session_dir'], sess_ids = get_session_dir(hparams)
            hparams['expt_dir'] = get_expt_dir(hparams)
            _, latents_file = get_transforms_paths('ae_latents', hparams,
                                                   sess_ids[sess_idx])
            ae_dir = os.path.dirname(latents_file)
        ae_model_file = os.path.join(ae_dir, 'best_val_model.pt')
        with open(os.path.join(ae_dir, 'meta_tags.pkl'), 'rb') as f:
            ae_arch = pickle.load(f)

        print('loading model from %s' % ae_model_file)
        ae_model = AE(ae_arch)
        ae_model.load_state_dict(
            torch.load(ae_model_file,
                       map_location=lambda storage, loc: storage))
        ae_model.eval()

        # push observed and sampled latents through ae decoder
        ims_recon = _decode_trials_to_frames(
            ae_model, model_output['latents'][dtype], max_frames, n_buffer)
        ims_recon_samp = _decode_trials_to_frames(
            ae_model, model_output['latents_gen'], max_frames, n_buffer)

        make_real_vs_sampled_movies(ims_recon,
                                    ims_recon_samp,
                                    conditional=conditional,
                                    save_file=save_file,
                                    frame_rate=frame_rate)

    if output_type == 'movie':
        return None

    # output_type is 'plot' or 'both': plot the first trial only
    i_trial = 0
    latents = model_output['latents'][dtype][i_trial][:max_frames]
    states = model_output['states'][dtype][i_trial][:max_frames]
    latents_samp = model_output['latents_gen'][i_trial][:max_frames]
    if conditional:
        # conditional samples produce no separately sampled state sequence
        states_samp = []
    else:
        states_samp = model_output['states_gen'][i_trial][:max_frames]

    fig = plot_real_vs_sampled(latents,
                               latents_samp,
                               states,
                               states_samp,
                               save_file=save_file,
                               xtick_locs=xtick_locs,
                               frame_rate=frame_rate_beh,
                               format=format)
    return fig


def _decode_trials_to_frames(ae_model, trial_latents, max_frames, n_buffer):
    """Decode per-trial latents with the AE decoder and stack the frames of successive trials.

    Trials are decoded in order until at least :obj:`max_frames` frames (blank buffer frames
    included) have been collected, or until the trials run out; :obj:`n_buffer` blank frames are
    appended after every decoded trial. The result is not truncated to :obj:`max_frames`.

    """
    n_channels = ae_model.hparams['n_input_channels']
    y_pix = ae_model.hparams['y_pixels']
    x_pix = ae_model.hparams['x_pixels']
    frame_shape = (n_channels * y_pix, x_pix)

    # seed with an empty array so that the output is well-defined even with no trials
    chunks = [np.zeros((0,) + frame_shape)]
    n_frames = 0
    for latents in trial_latents:
        if n_frames >= max_frames:
            break
        recon = ae_model.decoding(
            torch.tensor(latents).float(), None, None).cpu().detach().numpy()
        # join the slices along axis 1 of the decoder output into one frame per time point;
        # presumably these slices are the image channels - TODO confirm
        recon = np.concatenate(
            [recon[:, i] for i in range(recon.shape[1])], axis=1)
        chunks.append(recon)
        chunks.append(np.zeros((n_buffer,) + frame_shape))  # add a few black frames
        n_frames += recon.shape[0] + n_buffer
    # single concatenation instead of regrowing the array for every trial
    return np.concatenate(chunks, axis=0)
示例#2
0
def get_model_latents_states(hparams,
                             version,
                             sess_idx=0,
                             return_samples=0,
                             cond_sampling=False,
                             dtype='test',
                             dtypes=('train', 'val', 'test'),
                             rng_seed=0):
    """Return arhmm defined in :obj:`hparams` with associated latents and states.

    Can also return sampled latents and states.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm; note that 'session_dir' and
        'expt_dir' entries are written into this dict
    version : :obj:`str` or :obj:`int`
        if 'best', the best model version in the experiment directory is used; for any other value
        the version is looked up from :obj:`hparams` with :obj:`experiment_exists`
    sess_idx : :obj:`int`, optional
        session index into data generator
    return_samples : :obj:`int`, optional
        number of trials to sample from model; samples are only generated if this is positive.
        With :obj:`cond_sampling=True` one sample is generated per trial of type :obj:`dtype`,
        regardless of the exact value.
    cond_sampling : :obj:`bool`, optional
        if :obj:`True` return samples conditioned on most likely state sequence; else return
        unconditioned samples
    dtype : :obj:`str`, optional
        trial type to use for conditional sampling; 'train' | 'val' | 'test'. Must be contained
        in :obj:`dtypes` when samples are requested.
    dtypes : :obj:`array-like`, optional
        trial types for which to collect latents and states
    rng_seed : :obj:`int`, optional
        random number generator seed to control sampling (seeds the global numpy generator)

    Returns
    -------
    :obj:`dict`
        - 'model' (:obj:`ssm.HMM` object)
        - 'latents' (:obj:`dict`): latents from train, val and test trials
        - 'states' (:obj:`dict`): states from train, val and test trials
        - 'trial_idxs' (:obj:`dict`): trial indices from train, val and test trials
        - 'latents_gen' (:obj:`list`)
        - 'states_gen' (:obj:`list`): empty when :obj:`cond_sampling=True`

    Raises
    ------
    FileNotFoundError
        if no model version could be determined

    """
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # get version/model
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'],
                                         measure='val_loss',
                                         best_def='max')[0]
    else:
        _, version = experiment_exists(hparams, which_version=True)
    if version is None:
        raise FileNotFoundError(
            'Could not find the specified model version in %s' %
            hparams['expt_dir'])

    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # load latents/labels
    if 'labels' in hparams['model_class']:
        from behavenet.data.utils import load_labels_like_latents
        all_latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            all_latents = pickle.load(f)

    # collect inferred latents/states
    trial_idxs = {}
    latents = {}
    states = {}
    for data_type in dtypes:
        trial_idxs[data_type] = all_latents['trials'][data_type]
        latents[data_type] = [
            all_latents['latents'][i_trial]
            for i_trial in trial_idxs[data_type]
        ]
        states[data_type] = [
            hmm.most_likely_states(x) for x in latents[data_type]
        ]

    # collect sampled latents/states
    latents_gen = []
    states_gen = []
    if return_samples > 0:
        np.random.seed(rng_seed)
        if cond_sampling:
            # one sampled latent sequence per trial, following that trial's inferred states
            n_latents = latents[dtype][0].shape[1]
            latents_gen = [
                np.zeros((len(state_seg), n_latents))
                for state_seg in states[dtype]
            ]
            for i_seg, state_seg in enumerate(states[dtype]):
                seg_gen = latents_gen[i_seg]
                for i_t, state in enumerate(state_seg):
                    if i_t >= 1:
                        # condition on everything sampled so far in this trial
                        history = seg_gen[:i_t]
                    else:
                        # first time point: condition on the first observed latent
                        history = latents[dtype][i_seg][0].reshape(
                            (1, n_latents))
                    seg_gen[i_t] = hmm.observations.sample_x(
                        state, history, input=np.zeros(0))
        else:
            # sample extra leading time points and discard them
            offset = 200
            n_time = latents[dtype][0].shape[0] + offset
            for _ in range(return_samples):
                these_states_gen, these_latents_gen = hmm.sample(n_time)
                latents_gen.append(these_latents_gen[offset:])
                states_gen.append(these_states_gen[offset:])

    return_dict = {
        'model': hmm,
        'latents': latents,
        'states': states,
        'trial_idxs': trial_idxs,
        'latents_gen': latents_gen,
        'states_gen': states_gen,
    }
    return return_dict
示例#3
0
def make_syllable_movies_wrapper(hparams,
                                 save_file,
                                 sess_idx=0,
                                 dtype='test',
                                 max_frames=400,
                                 frame_rate=10,
                                 min_threshold=0,
                                 n_buffer=5,
                                 n_pre_frames=3,
                                 n_rows=None,
                                 single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    This is a high-level function that loads the arhmm model described in the hparams dictionary
    and produces the necessary states/video frames.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm; note that 'session_dir',
        'expt_dir' and 'load_videos' entries are written into this dict
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make video with; 'train' | 'val' | 'test'
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    min_threshold : :obj:`int`, optional
        minimum number of frames in a syllable run to be considered for movie; only runs
        strictly longer than this value are kept
    n_buffer : :obj:`int`
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`
        choose only a single state for movie

    Raises
    ------
    FileNotFoundError
        if no model version could be determined from :obj:`hparams`
    ValueError
        if there are no trials of type :obj:`dtype`

    """
    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # load images, latents, and states
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)
    hparams['load_videos'] = True
    hparams, signals, transforms, paths = get_data_generator_inputs(
        hparams, sess_ids)
    # NOTE(review): the generator receives single-element signal/transform/path lists but its
    # datasets are indexed with `sess_idx` below; confirm this is valid for sess_idx > 0
    data_generator = ConcatSessionsGenerator(
        hparams['data_dir'],
        sess_ids,
        signals_list=[signals[sess_idx]],
        transforms_list=[transforms[sess_idx]],
        paths_list=[paths[sess_idx]],
        device='cpu',
        as_numpy=True,
        batch_load=False,
        rng_seed=hparams['rng_seed_data'])
    ims_orig = data_generator.datasets[sess_idx].data['images']
    del data_generator  # free up memory

    # get tt version number
    _, version = experiment_exists(hparams, which_version=True)
    if version is None:
        # same failure mode as get_model_latents_states, instead of a TypeError when the
        # missing version is later formatted into the model path
        raise FileNotFoundError(
            'Could not find the specified model version in %s' %
            hparams['expt_dir'])
    print('producing syllable videos for arhmm %s' % version)

    # load latents/labels
    if 'labels' in hparams['model_class']:
        from behavenet.data.utils import load_labels_like_latents
        latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            latents = pickle.load(f)
    trial_idxs = latents['trials'][dtype]

    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # infer discrete states
    states = [hmm.most_likely_states(latents['latents'][s]) for s in trial_idxs]
    if len(states) == 0:
        raise ValueError('No latents for dtype=%s' % dtype)

    # find runs of discrete states; state indices is a list, each entry of which is a np array with
    # shape (n_state_instances, 3), where the 3 values are:
    # chunk_idx, chunk_start_idx, chunk_end_idx
    # chunk_idx is in [0, n_chunks], and indexes trial_idxs
    state_indices = get_discrete_chunks(states, include_edges=True)

    # get all example over minimum state length threshold
    over_threshold_instances = [[] for _ in state_indices]
    for i_state, instances in enumerate(state_indices):
        if instances.shape[0] > 0:
            # run length = chunk_end_idx - chunk_start_idx
            state_lens = np.diff(instances[:, 1:3], axis=1)
            over_idxs = state_lens > min_threshold
            over_threshold_instances[i_state] = instances[over_idxs[:, 0]]
            np.random.shuffle(
                over_threshold_instances[i_state])  # shuffle instances

    make_syllable_movies(ims_orig=ims_orig,
                         state_list=over_threshold_instances,
                         trial_idxs=trial_idxs,
                         save_file=save_file,
                         max_frames=max_frames,
                         frame_rate=frame_rate,
                         n_buffer=n_buffer,
                         n_pre_frames=n_pre_frames,
                         n_rows=n_rows,
                         single_syllable=single_syllable)
示例#4
0
def get_transforms_paths(data_type, hparams, sess_id):
    """Helper function for generating session-specific transforms and paths.

    Parameters
    ----------
    data_type : :obj:`str`
        'neural' | 'ae_latents' | 'arhmm_states' | 'neural_ae_predictions' |
        'neural_arhmm_predictions'; the short forms 'latents', 'states', 'ae_predictions' and
        'arhmm_predictions' are accepted for the last four options
    hparams : :obj:`dict`
        - required keys for :obj:`data_type=neural`: 'neural_type', 'neural_thresh'
        - required keys for :obj:`data_type=ae_latents`: 'ae_experiment_name', 'ae_model_type',
          'n_ae_latents', 'ae_version' or 'ae_latents_file'; this last option defines either the
          specific ae version (as 'best' or an int) or a path to a specific ae latents pickle file.
        - required keys for :obj:`data_type=arhmm_states`: 'arhmm_experiment_name',
          'n_arhmm_states', 'kappa', 'noise_type', 'n_ae_latents', 'arhmm_version' or
          'arhmm_states_file'; this last option defines either the specific arhmm version (as
          'best' or an int) or a path to a specific arhmm states pickle file. The key
          'arhmm_state_file' is accepted as an alternative spelling of 'arhmm_states_file'.
        - required keys for :obj:`data_type=neural_ae_predictions`: 'neural_ae_experiment_name',
          'neural_ae_model_type', 'neural_ae_version' or 'ae_predictions_file' plus keys for neural
          and ae_latents data types.
        - required keys for :obj:`data_type=neural_arhmm_predictions`:
          'neural_arhmm_experiment_name', 'neural_arhmm_model_type', 'neural_arhmm_version' or
          'arhmm_predictions_file', plus keys for neural and arhmm_states data types.
    sess_id : :obj:`dict` or :obj:`NoneType`
        session-specific dict with keys 'lab', 'expt', 'animal', 'session'; if :obj:`None` these
        values are taken from :obj:`hparams`

    Returns
    -------
    :obj:`tuple`
        - transform: session-specific transform object, or :obj:`None` if no transform applies
        - path (:obj:`str`): session-specific path

    Raises
    ------
    ValueError
        if :obj:`data_type`, the neural type or the region sampling option is invalid

    """

    from behavenet.data.transforms import SelectIdxs
    from behavenet.data.transforms import Threshold
    from behavenet.data.transforms import ZScore
    from behavenet.data.transforms import BlockShuffle
    from behavenet.data.transforms import Compose
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir

    def _version_dir(version_key, expt_dir, **best_kwargs):
        """Return 'version_N' from an int hparams entry, else from the best model version."""
        if version_key in hparams and isinstance(hparams[version_key], int):
            return 'version_%i' % hparams[version_key]
        return 'version_%i' % get_best_model_version(
            expt_dir, 'val_loss', **best_kwargs)[0]

    # fall back on the session defined in hparams; these keys are only read when needed
    if sess_id is None:
        sess_id = {
            key: hparams[key]
            for key in ['lab', 'expt', 'animal', 'session']
        }

    sess_id_str = str('%s_%s_%s_%s_' % (sess_id['lab'], sess_id['expt'],
                                        sess_id['animal'], sess_id['session']))

    if data_type == 'neural':

        path = os.path.join(hparams['data_dir'], sess_id['lab'],
                            sess_id['expt'], sess_id['animal'],
                            sess_id['session'], 'data.hdf5')

        transforms_ = []

        # filter neural data by region
        if hparams.get('subsample_regions', 'none') != 'none':
            # get region and indices
            sampling = hparams['subsample_regions']
            region_name = hparams['region']
            regions = get_region_list(hparams)
            if sampling == 'single':
                idxs = regions[region_name]
            elif sampling == 'loo':
                # leave-one-out: keep indices from every region but the named one
                idxs = np.concatenate([
                    reg_idxs for reg_name, reg_idxs in regions.items()
                    if reg_name != region_name
                ])
            else:
                raise ValueError('"%s" is an invalid region sampling option' %
                                 sampling)
            transforms_.append(
                SelectIdxs(idxs, str('%s-%s' % (region_name, sampling))))

        # filter neural data by activity
        if hparams['neural_type'] == 'spikes':
            if hparams['neural_thresh'] > 0:
                transforms_.append(
                    Threshold(threshold=hparams['neural_thresh'],
                              bin_size=hparams['neural_bin_size']))
        elif hparams['neural_type'] == 'ca':
            if hparams['model_type'][-6:] != 'neural':
                # don't zscore if predicting calcium activity
                transforms_.append(ZScore())
        else:
            raise ValueError('"%s" is an invalid neural type' %
                             hparams['neural_type'])

        # compose filters
        transform = Compose(transforms_) if len(transforms_) > 0 else None

    elif data_type == 'ae_latents' or data_type == 'latents':

        ae_dir = get_expt_dir(hparams,
                              model_class='ae',
                              expt_name=hparams['ae_experiment_name'],
                              model_type=hparams['ae_model_type'])

        transform = None

        if 'ae_latents_file' in hparams:
            path = hparams['ae_latents_file']
        else:
            path = os.path.join(ae_dir, _version_dir('ae_version', ae_dir),
                                str('%slatents.pkl' % sess_id_str))

    elif data_type == 'arhmm_states' or data_type == 'states':

        arhmm_dir = get_expt_dir(hparams,
                                 model_class='arhmm',
                                 expt_name=hparams['arhmm_experiment_name'])

        if hparams.get('shuffle_rng_seed') is not None:
            transform = BlockShuffle(hparams['shuffle_rng_seed'])
        else:
            transform = None

        # 'arhmm_states_file' is the documented key; 'arhmm_state_file' is the spelling this
        # function originally checked and is kept for backward compatibility
        if 'arhmm_states_file' in hparams:
            path = hparams['arhmm_states_file']
        elif 'arhmm_state_file' in hparams:
            path = hparams['arhmm_state_file']
        else:
            arhmm_version = _version_dir(
                'arhmm_version', arhmm_dir, best_def='max')
            path = os.path.join(arhmm_dir, arhmm_version,
                                str('%sstates.pkl' % sess_id_str))

    elif data_type == 'neural_ae_predictions' or data_type == 'ae_predictions':

        neural_ae_dir = get_expt_dir(
            hparams,
            model_class='neural-ae',
            expt_name=hparams['neural_ae_experiment_name'],
            model_type=hparams['neural_ae_model_type'])

        transform = None

        if 'ae_predictions_file' in hparams:
            path = hparams['ae_predictions_file']
        else:
            path = os.path.join(
                neural_ae_dir,
                _version_dir('neural_ae_version', neural_ae_dir),
                str('%spredictions.pkl' % sess_id_str))

    elif data_type == 'neural_arhmm_predictions' or data_type == 'arhmm_predictions':

        neural_arhmm_dir = get_expt_dir(
            hparams,
            model_class='neural-arhmm',
            expt_name=hparams['neural_arhmm_experiment_name'],
            model_type=hparams['neural_arhmm_model_type'])

        transform = None

        if 'arhmm_predictions_file' in hparams:
            path = hparams['arhmm_predictions_file']
        else:
            path = os.path.join(
                neural_arhmm_dir,
                _version_dir('neural_arhmm_version', neural_arhmm_dir),
                str('%spredictions.pkl' % sess_id_str))

    else:
        raise ValueError('"%s" is an invalid data_type' % data_type)

    return transform, path
示例#5
0
def get_r2s_by_trial(hparams, model_types):
    """For a given session, load R^2 metrics from all decoders defined by hparams.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify decoders; note that 'region' and
        'session_dir' entries are written into this dict
    model_types : :obj:`list` of :obj:`strs`
        'mlp' | 'mlp-mv' | 'lstm'

    Returns
    -------
    :obj:`pd.DataFrame`
        pandas dataframe of decoder validation metrics

    Raises
    ------
    ValueError
        raised by :obj:`pd.concat` if no metrics were found for any region/model type

    """

    dataset = _get_dataset_str(hparams)
    region_names = get_region_list(hparams)

    metrics = []
    model_idx = 0  # offset that keeps version labels unique across region/model combinations
    for region in region_names:
        hparams['region'] = region
        for model_type in model_types:

            hparams['session_dir'], _ = get_session_dir(
                hparams, session_source=hparams.get('all_source', 'save'))
            expt_dir = get_expt_dir(hparams,
                                    model_type=model_type,
                                    model_class=hparams['model_class'],
                                    expt_name=hparams['experiment_name'])

            # gather all versions
            try:
                versions = get_subdirs(expt_dir)
            except Exception:
                print('No models in %s; skipping' % expt_dir)
                # really skip: without this the loop below would hit an undefined name or
                # silently reuse the versions of the previous experiment directory
                versions = []

            # load csv files with model metrics (saved out from test tube)
            for version in versions:
                # read metrics csv file
                model_dir = os.path.join(expt_dir, version)
                try:
                    metric = pd.read_csv(os.path.join(model_dir,
                                                      'metrics.csv'))
                except FileNotFoundError:
                    continue
                # keep the loaded model hparams separate from the `hparams` argument, which is
                # still needed to locate the remaining experiment directories
                with open(os.path.join(model_dir, 'meta_tags.pkl'), 'rb') as f:
                    model_hparams = pickle.load(f)
                # append model info to metrics; strip the 'version_' prefix of the directory
                # name and add the numeric offset (not string concatenation)
                version_num = int(version[8:])
                metric['version'] = 'version_%i' % (model_idx + version_num)
                metric['region'] = region
                metric['dataset'] = dataset
                metric['model_type'] = model_type
                for key, val in model_hparams.items():
                    if isinstance(val, (str, int, float)):
                        metric[key] = val
                metrics.append(metric)

            model_idx += 10000  # assumes no more than 10k model versions/expt
    # put everything in pandas dataframe
    metrics_df = pd.concat(metrics, sort=False)
    return metrics_df
示例#6
0
def get_transforms_paths(data_type, hparams, sess_id, check_splits=True):
    """Helper function for generating session-specific transforms and paths.

    Parameters
    ----------
    data_type : :obj:`str`
        'neural' | 'ae_latents' | 'ae_latents_me' | 'arhmm_states' | 'neural_ae_predictions' |
        'neural_arhmm_predictions'; the short forms 'latents', 'latents_me', 'states',
        'ae_predictions' and 'arhmm_predictions' are accepted as aliases
    hparams : :obj:`dict`
        - required keys for :obj:`data_type=neural`: 'data_dir', 'neural_type'; additionally
          'neural_thresh' and 'neural_bin_size' when 'neural_type' is 'spikes', and 'model_type'
          when 'neural_type' is 'ca'
        - required keys for :obj:`data_type=ae_latents`: 'ae_latents_file', or else
          'ae_model_class', 'ae_experiment_name', 'ae_model_type' and optionally 'ae_version'
          ('best', an int, or an int given as a string); without 'ae_version' the best version is
          used.
        - required keys for :obj:`data_type=arhmm_states`: 'arhmm_states_file' (path to a
          specific arhmm states pickle file), or else 'arhmm_experiment_name' and optionally
          'arhmm_version' (an int; any other value selects the best version).
        - required keys for :obj:`data_type=neural_ae_predictions`: 'ae_predictions_file', or
          else 'neural_ae_experiment_name', 'neural_ae_model_type' and optionally
          'neural_ae_version' (an int; any other value selects the best version).
        - required keys for :obj:`data_type=neural_arhmm_predictions`: 'arhmm_predictions_file',
          or else 'neural_arhmm_experiment_name', 'neural_arhmm_model_type' and optionally
          'neural_arhmm_version' (an int; any other value selects the best version).
        - 'lab', 'expt', 'animal', 'session' are only read when :obj:`sess_id` is :obj:`None`
        - whatever additional keys :obj:`get_expt_dir` needs to build the experiment directory
    sess_id : :obj:`dict` or :obj:`NoneType`
        session-specific dict with keys 'lab', 'expt', 'animal', 'session'; if :obj:`None` the
        session is taken from the same keys of :obj:`hparams`
    check_splits : :obj:`bool`, optional
        check data splits and data rng seed between hparams and loaded model outputs (e.g. latents)

    Returns
    -------
    :obj:`tuple`
        - transform (:obj:`behavenet.data.transforms.Transform` object): session-specific transform
        - path (:obj:`str`): session-specific path

    Raises
    ------
    ValueError
        if :obj:`data_type`, :obj:`hparams['neural_type']` or :obj:`hparams['subsample_method']`
        is not one of the recognized options

    """

    from behavenet.data.transforms import BlockShuffle
    from behavenet.data.transforms import Compose
    from behavenet.data.transforms import MotionEnergy
    from behavenet.data.transforms import SelectIdxs
    from behavenet.data.transforms import Threshold
    from behavenet.data.transforms import ZScore
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir

    # only fall back on the session described in hparams when no session is supplied, so that
    # callers passing an explicit sess_id do not need the session keys in hparams
    if sess_id is None:
        sess_id = {key: hparams[key] for key in ['lab', 'expt', 'animal', 'session']}

    # prefix shared by all session-specific model output files
    sess_id_str = '%s_%s_%s_%s_' % (
        sess_id['lab'], sess_id['expt'], sess_id['animal'], sess_id['session'])

    def _version_dir(expt_dir, version_key, **best_kwargs):
        """Return 'version_N' for an int `hparams[version_key]`, else for the best version."""
        version = hparams.get(version_key)
        if isinstance(version, int):
            return 'version_%i' % version
        return 'version_%i' % get_best_model_version(expt_dir, 'val_loss', **best_kwargs)[0]

    if data_type == 'neural':

        # the path below points at the session data file rather than a fitted model output, so
        # the training split check is skipped
        check_splits = False

        path = os.path.join(
            hparams['data_dir'], sess_id['lab'], sess_id['expt'], sess_id['animal'],
            sess_id['session'], 'data.hdf5')

        transforms_ = []

        # filter neural data by indices (regions, cell types, etc)
        if hparams.get('subsample_method', 'none') != 'none':
            sampling = hparams['subsample_method']
            idxs_name = hparams['subsample_idxs_name']
            idxs_dict = get_region_list(hparams)
            if sampling == 'single':
                # keep only the named index group
                idxs = idxs_dict[idxs_name]
            elif sampling == 'loo':
                # leave-one-out: keep every index group except the named one
                idxs = np.concatenate(
                    [val for key, val in idxs_dict.items() if key != idxs_name])
            else:
                raise ValueError('"%s" is an invalid index sampling option' % sampling)
            transforms_.append(SelectIdxs(idxs, '%s-%s' % (idxs_name, sampling)))

        # filter neural data by activity
        neural_type = hparams['neural_type']
        if neural_type == 'spikes':
            if hparams['neural_thresh'] > 0:
                transforms_.append(Threshold(
                    threshold=hparams['neural_thresh'],
                    bin_size=hparams['neural_bin_size']))
        elif neural_type == 'ca':
            if not hparams['model_type'].endswith('neural'):
                # don't zscore if predicting calcium activity
                transforms_.append(ZScore())
        elif neural_type == 'ca-zscored':
            pass
        else:
            raise ValueError('"%s" is an invalid neural type' % neural_type)

        # compose filters
        transform = Compose(transforms_) if transforms_ else None

    elif data_type in ('ae_latents', 'latents', 'ae_latents_me', 'latents_me'):

        transform = MotionEnergy() if data_type.endswith('_me') else None

        if 'ae_latents_file' in hparams:
            path = hparams['ae_latents_file']
        else:
            ae_dir = get_expt_dir(
                hparams, model_class=hparams['ae_model_class'],
                expt_name=hparams['ae_experiment_name'],
                model_type=hparams['ae_model_type'])
            if hparams.get('ae_version', 'best') != 'best':
                # json args read as strings; the converted value is written back to hparams
                if isinstance(hparams['ae_version'], str):
                    hparams['ae_version'] = int(hparams['ae_version'])
                ae_version = 'version_%i' % hparams['ae_version']
            else:
                ae_version = 'version_%i' % get_best_model_version(ae_dir, 'val_loss')[0]
            path = os.path.join(ae_dir, ae_version, '%slatents.pkl' % sess_id_str)

    elif data_type in ('arhmm_states', 'states'):

        if hparams.get('shuffle_rng_seed') is not None:
            transform = BlockShuffle(hparams['shuffle_rng_seed'])
        else:
            transform = None

        if 'arhmm_states_file' in hparams:
            path = hparams['arhmm_states_file']
        else:
            arhmm_dir = get_expt_dir(
                hparams, model_class='arhmm', expt_name=hparams['arhmm_experiment_name'])
            arhmm_version = _version_dir(arhmm_dir, 'arhmm_version', best_def='min')
            path = os.path.join(arhmm_dir, arhmm_version, '%sstates.pkl' % sess_id_str)

    elif data_type in ('neural_ae_predictions', 'ae_predictions'):

        transform = None

        if 'ae_predictions_file' in hparams:
            path = hparams['ae_predictions_file']
        else:
            neural_ae_dir = get_expt_dir(
                hparams, model_class='neural-ae',
                expt_name=hparams['neural_ae_experiment_name'],
                model_type=hparams['neural_ae_model_type'])
            neural_ae_version = _version_dir(neural_ae_dir, 'neural_ae_version')
            path = os.path.join(
                neural_ae_dir, neural_ae_version, '%spredictions.pkl' % sess_id_str)

    elif data_type in ('neural_arhmm_predictions', 'arhmm_predictions'):

        transform = None

        if 'arhmm_predictions_file' in hparams:
            path = hparams['arhmm_predictions_file']
        else:
            neural_arhmm_dir = get_expt_dir(
                hparams, model_class='neural-arhmm',
                expt_name=hparams['neural_arhmm_experiment_name'],
                model_type=hparams['neural_arhmm_model_type'])
            neural_arhmm_version = _version_dir(neural_arhmm_dir, 'neural_arhmm_version')
            path = os.path.join(
                neural_arhmm_dir, neural_arhmm_version, '%spredictions.pkl' % sess_id_str)

    else:
        raise ValueError('"%s" is an invalid data_type' % data_type)

    # check training data split is the same
    if check_splits:
        check_same_training_split(path, hparams)

    return transform, path
示例#7
0
def load_metrics_csv_as_df(hparams,
                           lab,
                           expt,
                           metrics_list,
                           test=False,
                           version='best'):
    """Load metrics csv file and return as a pandas dataframe for easy plotting.

    The returned dataframe holds one row per (csv row, data split, metric) combination, with
    columns 'dataset', 'epoch', 'dtype', 'loss' (the metric name) and 'val' (the metric value).
    This long format is convenient for multi-axis plotting, e.g. with seaborn's FacetGrid.

    Parameters
    ----------
    hparams : :obj:`dict`
        requires `sessions_csv`, `multisession`, `lab`, `expt`, `animal` and `session`; the keys
        'session_dir' and 'expt_dir' are filled in (overwritten) by this function
    lab : :obj:`str`
        for `get_lab_example`
    expt : :obj:`str`
        for `get_lab_example`
    metrics_list : :obj:`list`
        names of metrics to pull from csv; do not prepend with 'tr', 'val', or 'test'
    test : :obj:`bool`
        True to only return test values (computed once at end of training); False to return
        validation and training values
    version: :obj:`str`
        `best` to find best model in tt expt, None to find model with hyperparams defined in
        `hparams`, int to load specific model

    Returns
    -------
    :obj:`pandas.DataFrame` object

    """

    # programmatically fill out other hparams options
    get_lab_example(hparams, lab, expt)
    hparams['session_dir'], _ = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # find metrics csv file; strings must be compared by value (`==`), not identity (`is`)
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'])[0]
    elif not isinstance(version, int):
        # look up the version whose hyperparameters match `hparams`
        _, version = experiment_exists(hparams, which_version=True)
    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)
    metrics = pd.read_csv(os.path.join(version_dir, 'metrics.csv'))

    # human-readable name for each session the model was fit on
    sess_ids = read_session_info_from_csv(os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = ['%s/%s' % (sess_id['animal'], sess_id['session']) for sess_id in sess_ids]

    # (dtype label, csv column prefix) pairs, in the order rows are emitted for each csv row
    if test:
        splits = [('test', 'test')]
    else:
        splits = [('val', 'val'), ('train', 'tr')]

    # collect data from csv file
    metrics_df = []
    for _, row in metrics.iterrows():
        dataset_idx = row['dataset']
        if dataset_idx == -1:
            dataset = 'all'
        else:
            # `iterrows` may upcast the index to float; lists require an int index
            dataset = sess_ids_strs[int(dataset_idx)]
        for dtype, prefix in splits:
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            'dataset': dataset,
                            'epoch': row['epoch'],
                            'dtype': dtype,
                            'loss': metric,
                            'val': row['%s_%s' % (prefix, metric)]
                        },
                        index=[0]))
    return pd.concat(metrics_df, sort=True)
    def test_get_expt_dir(self):
        """Check the directory returned by `utils.get_expt_dir` for a range of model classes.

        A single `hparams` dict is mutated section by section. In each section the expected
        experiment directory is assembled by hand with `os.path.join` and compared against the
        value returned by `utils.get_expt_dir`. Keys set in one section stay in `hparams` for
        later sections unless explicitly overwritten. The final visible section expects a
        `ValueError` for the model class 'testing'.
        """

        hparams = {
            'data_dir': 'ddir', 'save_dir': 'sdir', 'lab': 'lab0', 'expt': 'expt0',
            'animal': 'animal0', 'session': 'session-00'}
        session_dir = os.path.join(
            hparams['data_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            hparams['session'])
        hparams['session_dir'] = session_dir

        # -------------------------
        # ae
        # -------------------------
        hparams['model_class'] = 'ae'
        hparams['model_type'] = 'conv'
        hparams['n_ae_latents'] = 8
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'],
            '%02i_latents' % hparams['n_ae_latents'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # passing None for class/type/name is expected to give the same path as passing the
        # corresponding hparams values explicitly
        expt_dir = utils.get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None)
        assert expt_dir == model_path

        # multisession
        # NOTE(review): save_dir is pointed at self.tmpdir for this case, presumably because the
        # multisession lookup touches the filesystem - confirm against get_expt_dir
        hparams['save_dir'] = self.tmpdir
        hparams['ae_multisession'] = 0
        model_path = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            'multisession-%02i' % hparams['ae_multisession'], hparams['model_class'],
            hparams['model_type'], '%02i_latents' % hparams['n_ae_latents'],
            hparams['experiment_name'])
        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path
        # restore single-session settings for the sections below
        hparams['ae_multisession'] = None
        hparams['save_dir'] = 'sdir'

        # -------------------------
        # vae
        # -------------------------
        hparams['model_class'] = 'vae'
        hparams['model_type'] = 'conv'
        hparams['n_ae_latents'] = 10
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'],
            '%02i_latents' % hparams['n_ae_latents'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        expt_dir = utils.get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None)
        assert expt_dir == model_path

        # -------------------------
        # beta-tcvae
        # -------------------------
        hparams['model_class'] = 'beta-tcvae'
        hparams['model_type'] = 'conv'
        hparams['n_ae_latents'] = 10
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'],
            '%02i_latents' % hparams['n_ae_latents'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        expt_dir = utils.get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None)
        assert expt_dir == model_path

        # -------------------------
        # cond-vae
        # -------------------------
        hparams['model_class'] = 'cond-vae'
        hparams['model_type'] = 'conv'
        hparams['n_ae_latents'] = 8
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'],
            '%02i_latents' % hparams['n_ae_latents'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        expt_dir = utils.get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None)
        assert expt_dir == model_path

        # -------------------------
        # cond-ae [-msp]
        # -------------------------
        hparams['model_class'] = 'cond-ae'
        hparams['model_type'] = 'conv'
        hparams['n_ae_latents'] = 8
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'],
            '%02i_latents' % hparams['n_ae_latents'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        hparams['model_class'] = 'cond-ae-msp'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'],
            '%02i_latents' % hparams['n_ae_latents'], hparams['experiment_name'])
        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # -------------------------
        # ps-vae
        # -------------------------
        hparams['model_class'] = 'ps-vae'
        hparams['model_type'] = 'conv'
        hparams['n_ae_latents'] = 10
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'],
            '%02i_latents' % hparams['n_ae_latents'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        expt_dir = utils.get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None)
        assert expt_dir == model_path

        # -------------------------
        # neural-ae/neural-ae-me/ae-neural
        # -------------------------
        # for these classes the expected path places the latents directory before the model
        # type and adds an 'all' component before the experiment name
        hparams['model_class'] = 'neural-ae'
        hparams['model_type'] = 'mlp'
        hparams['n_ae_latents'] = 8
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            hparams['model_type'], 'all', hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        hparams['model_class'] = 'neural-ae-me'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            hparams['model_type'], 'all', hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        hparams['model_class'] = 'ae-neural'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            hparams['model_type'], 'all', hparams['experiment_name'])
        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # -------------------------
        # neural-labels/labels-neural
        # -------------------------
        hparams['model_class'] = 'neural-labels'
        hparams['model_type'] = 'mlp'
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'], 'all',
            hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        hparams['model_class'] = 'labels-neural'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'], 'all',
            hparams['experiment_name'])
        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # -------------------------
        # neural-arhmm/arhmm-neural
        # -------------------------
        hparams['model_class'] = 'neural-arhmm'
        hparams['model_type'] = 'mlp'
        hparams['n_ae_latents'] = 8
        hparams['n_arhmm_states'] = 10
        hparams['transitions'] = 'stationary'
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'], hparams['transitions'],
            hparams['model_type'], 'all', hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        hparams['model_class'] = 'arhmm-neural'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'],
            hparams['transitions'], hparams['model_type'], 'all', hparams['experiment_name'])
        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # sticky transitions: the expected transitions directory also encodes kappa ('%.0e')
        hparams['transitions'] = 'sticky'
        hparams['kappa'] = 100
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'],
            '%s_%.0e' % (hparams['transitions'], hparams['kappa']),
            hparams['model_type'], 'all', hparams['experiment_name'])
        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # -------------------------
        # arhmm/hmm
        # -------------------------
        # hparams['model_type'] is still 'mlp' from the sections above; it is passed to
        # get_expt_dir but is not part of the expected path for this class
        hparams['model_class'] = 'arhmm'
        hparams['n_ae_latents'] = 8
        hparams['n_arhmm_states'] = 10
        hparams['transitions'] = 'stationary'
        hparams['noise_type'] = 'gaussian'
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'], hparams['transitions'],
            hparams['noise_type'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # multisession
        hparams['save_dir'] = self.tmpdir
        hparams['arhmm_multisession'] = 0
        model_path = os.path.join(
            hparams['save_dir'], hparams['lab'], hparams['expt'], hparams['animal'],
            'multisession-%02i' % hparams['arhmm_multisession'], hparams['model_class'],
            '%02i_latents' % hparams['n_ae_latents'], '%02i_states' % hparams['n_arhmm_states'],
            hparams['transitions'], hparams['noise_type'], hparams['experiment_name'])
        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path
        # restore single-session settings for the sections below
        hparams['arhmm_multisession'] = None
        hparams['save_dir'] = 'sdir'

        # -------------------------
        # arhmm-labels
        # -------------------------
        # the expected path for this class has no latents directory
        hparams['model_class'] = 'arhmm-labels'
        hparams['n_arhmm_states'] = 10
        hparams['transitions'] = 'stationary'
        hparams['noise_type'] = 'studentst'
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_states' % hparams['n_arhmm_states'],
            hparams['transitions'], hparams['noise_type'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # -------------------------
        # bayesian decoding
        # -------------------------
        hparams['model_class'] = 'bayesian-decoding'
        hparams['n_ae_latents'] = 8
        hparams['n_arhmm_states'] = 10
        hparams['transitions'] = 'stationary'
        hparams['noise_type'] = 'studentst'
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'], hparams['transitions'],
            hparams['noise_type'], 'all', hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # -------------------------
        # labels-images
        # -------------------------
        hparams['model_class'] = 'labels-images'
        hparams['model_type'] = 'conv'
        hparams['experiment_name'] = 'tt_expt'
        model_path = os.path.join(
            session_dir, hparams['model_class'], hparams['model_type'], hparams['experiment_name'])

        expt_dir = utils.get_expt_dir(
            hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
            expt_name=hparams['experiment_name'])
        assert expt_dir == model_path

        # -------------------------
        # other
        # -------------------------
        # an unrecognized model class is expected to raise ValueError
        hparams['model_class'] = 'testing'
        hparams['model_type'] = 'conv'
        hparams['experiment_name'] = 'tt_expt'
        with pytest.raises(ValueError):
            utils.get_expt_dir(
                hparams, model_class=hparams['model_class'], model_type=hparams['model_type'],
                expt_name=hparams['experiment_name'])