Code example #1
# Module-level imports assumed from the source file:
import os
import pickle

import numpy as np
import torch

from behavenet.models import AE  # autoencoder class (import path assumed)

# get_model_latents_states (code example #4 below), make_real_vs_sampled_movies,
# and plot_real_vs_sampled are defined in the same module.

def real_vs_sampled_wrapper(output_type,
                            hparams,
                            save_file,
                            sess_idx,
                            dtype='test',
                            conditional=True,
                            max_frames=400,
                            frame_rate=20,
                            n_buffer=5,
                            xtick_locs=None,
                            frame_rate_beh=None,
                            format='png'):
    """Produce movie with (AE) reconstructed video and sampled video.

    This is a high-level function that loads the model described in the hparams dictionary and
    produces the necessary state sequences/samples. The sampled video can be completely
    unconditional (states and latents are sampled) or conditioned on the most likely state
    sequence.

    Parameters
    ----------
    output_type : :obj:`str`
        'plot' | 'movie' | 'both'
    hparams : :obj:`dict`
        needs to contain enough information to specify an autoencoder
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make plot/video with; 'train' | 'val' | 'test'
    conditional : :obj:`bool`, optional
        whether samples are conditioned on the most likely state sequence or fully
        unconditional; also used when creating the reconstruction title
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    n_buffer : :obj:`int`, optional
        number of blank frames between animated trials if more than one trial is needed to
        reach :obj:`max_frames`
    xtick_locs : :obj:`array-like`, optional
        tick locations in bin values for plot
    frame_rate_beh : :obj:`float`, optional
        behavioral video framerate; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle if :obj:`output_type='plot'` or :obj:`output_type='both'`, else
        nothing returned (movie is saved)

    """
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # check input - cannot create sampled movies for arhmm-labels models (no mapping from labels to
    # frames)
    if hparams['model_class'].find('labels') > -1:
        if output_type == 'both' or output_type == 'movie':
            print(
                'warning: cannot create video with "arhmm-labels" model; producing plots'
            )
            output_type = 'plot'

    # load latents and states (observed and sampled)
    model_output = get_model_latents_states(hparams,
                                            '',
                                            sess_idx=sess_idx,
                                            return_samples=50,
                                            cond_sampling=conditional)

    if output_type == 'both' or output_type == 'movie':

        # load in AE decoder
        if hparams.get('ae_model_path', None) is not None:
            ae_model_file = os.path.join(hparams['ae_model_path'],
                                         'best_val_model.pt')
            with open(os.path.join(hparams['ae_model_path'], 'meta_tags.pkl'),
                      'rb') as f:
                ae_arch = pickle.load(f)
        else:
            hparams['session_dir'], sess_ids = get_session_dir(hparams)
            hparams['expt_dir'] = get_expt_dir(hparams)
            _, latents_file = get_transforms_paths('ae_latents', hparams,
                                                   sess_ids[sess_idx])
            ae_model_file = os.path.join(os.path.dirname(latents_file),
                                         'best_val_model.pt')
            with open(os.path.join(os.path.dirname(latents_file),
                                   'meta_tags.pkl'), 'rb') as f:
                ae_arch = pickle.load(f)
        print('loading model from %s' % ae_model_file)
        ae_model = AE(ae_arch)
        ae_model.load_state_dict(
            torch.load(ae_model_file,
                       map_location=lambda storage, loc: storage))
        ae_model.eval()

        n_channels = ae_model.hparams['n_input_channels']
        y_pix = ae_model.hparams['y_pixels']
        x_pix = ae_model.hparams['x_pixels']

        # push observed latents through ae decoder
        ims_recon = np.zeros((0, n_channels * y_pix, x_pix))
        i_trial = 0
        while ims_recon.shape[0] < max_frames:
            recon = ae_model.decoding(
                torch.tensor(model_output['latents'][dtype][i_trial]).float(),
                None, None).cpu().detach().numpy()
            recon = np.concatenate(
                [recon[:, i] for i in range(recon.shape[1])], axis=1)
            zero_frames = np.zeros((n_buffer, n_channels * y_pix,
                                    x_pix))  # add a few black frames
            ims_recon = np.concatenate((ims_recon, recon, zero_frames), axis=0)
            i_trial += 1

        # push sampled latents through ae decoder
        ims_recon_samp = np.zeros((0, n_channels * y_pix, x_pix))
        i_trial = 0
        while ims_recon_samp.shape[0] < max_frames:
            recon = ae_model.decoding(
                torch.tensor(model_output['latents_gen'][i_trial]).float(),
                None, None).cpu().detach().numpy()
            recon = np.concatenate(
                [recon[:, i] for i in range(recon.shape[1])], axis=1)
            zero_frames = np.zeros((n_buffer, n_channels * y_pix,
                                    x_pix))  # add a few black frames
            ims_recon_samp = np.concatenate(
                (ims_recon_samp, recon, zero_frames), axis=0)
            i_trial += 1

        make_real_vs_sampled_movies(ims_recon,
                                    ims_recon_samp,
                                    conditional=conditional,
                                    save_file=save_file,
                                    frame_rate=frame_rate)

    if output_type == 'both' or output_type == 'plot':

        i_trial = 0
        latents = model_output['latents'][dtype][i_trial][:max_frames]
        states = model_output['states'][dtype][i_trial][:max_frames]
        latents_samp = model_output['latents_gen'][i_trial][:max_frames]
        if not conditional:
            states_samp = model_output['states_gen'][i_trial][:max_frames]
        else:
            states_samp = []

        fig = plot_real_vs_sampled(latents,
                                   latents_samp,
                                   states,
                                   states_samp,
                                   save_file=save_file,
                                   xtick_locs=xtick_locs,
                                   frame_rate=frame_rate_beh,
                                   format=format)

    if output_type == 'movie':
        return None
    elif output_type == 'both' or output_type == 'plot':
        return fig
    else:
        raise ValueError('"%s" is an invalid output_type' % output_type)
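
A minimal usage sketch (hypothetical values; a real hparams dict must fully specify the fitted arhmm and its autoencoder):

# hypothetical hparams -- placeholder values, not a working behavenet config
hparams = {
    'data_dir': '/path/to/data', 'results_dir': '/path/to/results',
    'lab': 'lab', 'expt': 'expt', 'animal': 'animal', 'session': 'session',
    'model_class': 'arhmm',
    # ... plus whatever keys are needed to locate the fitted arhmm and autoencoder
}
fig = real_vs_sampled_wrapper(
    'both', hparams, save_file='/path/to/output/real_vs_sampled',
    sess_idx=0, dtype='test', conditional=True, max_frames=200)
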
Code example #2
# Module-level imports assumed from the source file:
import os
import pickle

import numpy as np

# get_discrete_chunks and make_syllable_movies are defined in the same module.

def make_syllable_movies_wrapper(hparams,
                                 save_file,
                                 sess_idx=0,
                                 dtype='test',
                                 max_frames=400,
                                 frame_rate=10,
                                 min_threshold=0,
                                 n_buffer=5,
                                 n_pre_frames=3,
                                 n_rows=None,
                                 single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    This is a high-level function that loads the arhmm model described in the hparams dictionary
    and produces the necessary states/video frames.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make video with; 'train' | 'val' | 'test'
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    min_threshold : :obj:`int`, optional
        minimum number of frames in a syllable run to be considered for movie
    n_buffer : :obj:`int`, optional
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`, optional
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`, optional
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`, optional
        choose only a single state for movie

    """
    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # load images, latents, and states
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)
    hparams['load_videos'] = True
    hparams, signals, transforms, paths = get_data_generator_inputs(
        hparams, sess_ids)
    data_generator = ConcatSessionsGenerator(
        hparams['data_dir'],
        sess_ids,
        signals_list=[signals[sess_idx]],
        transforms_list=[transforms[sess_idx]],
        paths_list=[paths[sess_idx]],
        device='cpu',
        as_numpy=True,
        batch_load=False,
        rng_seed=hparams['rng_seed_data'])
    ims_orig = data_generator.datasets[sess_idx].data['images']
    del data_generator  # free up memory

    # get test-tube version number
    _, version = experiment_exists(hparams, which_version=True)
    print('producing syllable videos for arhmm %s' % version)
    # load latents/labels
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            latents = pickle.load(f)
    trial_idxs = latents['trials'][dtype]
    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)
    # infer discrete states
    states = [
        hmm.most_likely_states(latents['latents'][s])
        for s in latents['trials'][dtype]
    ]
    if len(states) == 0:
        raise ValueError('No latents for dtype=%s' % dtype)

    # find runs of discrete states; state_indices is a list, each entry of which is a np array
    # with shape (n_state_instances, 3), where the 3 values are:
    # chunk_idx, chunk_start_idx, chunk_end_idx
    # chunk_idx is in [0, n_chunks) and indexes trial_idxs
    state_indices = get_discrete_chunks(states, include_edges=True)
    K = len(state_indices)

    # get all examples over the minimum state-length threshold
    over_threshold_instances = [[] for _ in range(K)]
    for i_state in range(K):
        if state_indices[i_state].shape[0] > 0:
            state_lens = np.diff(state_indices[i_state][:, 1:3], axis=1)
            over_idxs = state_lens > min_threshold
            over_threshold_instances[i_state] = state_indices[i_state][
                over_idxs[:, 0]]
            np.random.shuffle(
                over_threshold_instances[i_state])  # shuffle instances

    make_syllable_movies(ims_orig=ims_orig,
                         state_list=over_threshold_instances,
                         trial_idxs=trial_idxs,
                         save_file=save_file,
                         max_frames=max_frames,
                         frame_rate=frame_rate,
                         n_buffer=n_buffer,
                         n_pre_frames=n_pre_frames,
                         n_rows=n_rows,
                         single_syllable=single_syllable)
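
A hypothetical call, reusing the placeholder hparams from above; min_threshold=5 drops syllable runs of five frames or fewer:

make_syllable_movies_wrapper(
    hparams, save_file='/path/to/output/syllables.mp4',
    sess_idx=0, dtype='test', max_frames=300, frame_rate=10,
    min_threshold=5, n_buffer=5, n_pre_frames=3)
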
Code example #3
File: decoder_utils.py  Project: nihaarshah/behavenet

# plot_neural_reconstruction_traces is defined in the same module; the remaining
# imports are made inside the function body.

def plot_neural_reconstruction_traces_wrapper(hparams,
                                              save_file=None,
                                              trial=None,
                                              xtick_locs=None,
                                              frame_rate=None,
                                              format='png',
                                              **kwargs):
    """Plot ae latents and their neural reconstructions.

    This is a high-level function that loads the model described in the hparams dictionary and
    produces the necessary predicted latents.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an ae latent decoder
    save_file : :obj:`str`
        full save file (path and filename)
    trial : :obj:`int`, optional
        if :obj:`NoneType`, use first test trial
    xtick_locs : :obj:`array-like`, optional
        tick locations in units of bins
    frame_rate : :obj:`float`, optional
        frame rate of behavioral video; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle of plot

    """

    # find good trials
    import copy
    from behavenet.data.utils import get_transforms_paths
    from behavenet.data.data_generator import ConcatSessionsGenerator

    # ae data
    hparams_ae = copy.copy(hparams)
    hparams_ae['experiment_name'] = hparams['ae_experiment_name']
    hparams_ae['model_class'] = hparams['ae_model_class']
    hparams_ae['model_type'] = hparams['ae_model_type']

    ae_transform, ae_path = get_transforms_paths('ae_latents', hparams_ae,
                                                 None)

    # ae predictions data
    hparams_dec = copy.copy(hparams)
    hparams_dec['neural_ae_experiment_name'] = hparams[
        'decoder_experiment_name']
    hparams_dec['neural_ae_model_class'] = hparams['decoder_model_class']
    hparams_dec['neural_ae_model_type'] = hparams['decoder_model_type']
    ae_pred_transform, ae_pred_path = get_transforms_paths(
        'neural_ae_predictions', hparams_dec, None)

    signals = ['ae_latents', 'ae_predictions']
    transforms = [ae_transform, ae_pred_transform]
    paths = [ae_path, ae_pred_path]

    data_generator = ConcatSessionsGenerator(hparams['data_dir'], [hparams],
                                             signals_list=[signals],
                                             transforms_list=[transforms],
                                             paths_list=[paths],
                                             device='cpu',
                                             as_numpy=False,
                                             batch_load=True,
                                             rng_seed=0)

    if trial is None:
        # choose first test trial
        trial = data_generator.datasets[0].batch_idxs['test'][0]

    batch = data_generator.datasets[0][trial]
    traces_ae = batch['ae_latents'].cpu().detach().numpy()
    traces_neural = batch['ae_predictions'].cpu().detach().numpy()

    n_max_lags = hparams.get('n_max_lags',
                             0)  # only plot valid segment of data
    if n_max_lags > 0:
        fig = plot_neural_reconstruction_traces(
            traces_ae[n_max_lags:-n_max_lags],
            traces_neural[n_max_lags:-n_max_lags], save_file, xtick_locs,
            frame_rate, format, **kwargs)
    else:
        fig = plot_neural_reconstruction_traces(traces_ae, traces_neural,
                                                save_file, xtick_locs,
                                                frame_rate, format, **kwargs)
    return fig
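
A hypothetical call; in addition to the usual session keys, hparams must carry the ae_experiment_name / ae_model_class / ae_model_type and decoder_experiment_name / decoder_model_class / decoder_model_type entries referenced in the function body:

fig = plot_neural_reconstruction_traces_wrapper(
    hparams,
    save_file='/path/to/output/recon_traces',
    trial=None,                # None -> first test trial
    xtick_locs=[0, 100, 200],  # tick positions in units of bins
    frame_rate=30.0,           # behavioral video frame rate, for relabeling xticks
    format='pdf')
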
Code example #4
# Module-level imports assumed from the source file:
import os
import pickle

import numpy as np

def get_model_latents_states(hparams,
                             version,
                             sess_idx=0,
                             return_samples=0,
                             cond_sampling=False,
                             dtype='test',
                             dtypes=['train', 'val', 'test'],
                             rng_seed=0):
    """Return arhmm defined in :obj:`hparams` with associated latents and states.

    Can also return sampled latents and states.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    version : :obj:`str` or :obj:`int`
        test tube model version (can be 'best')
    sess_idx : :obj:`int`, optional
        session index into data generator
    return_samples : :obj:`int`, optional
        number of trials to sample from model
    cond_sampling : :obj:`bool`, optional
        if :obj:`True` return samples conditioned on most likely state sequence; else return
        unconditioned samples
    dtype : :obj:`str`, optional
        trial type to use for conditional sampling; 'train' | 'val' | 'test'
    dtypes : :obj:`array-like`, optional
        trial types for which to collect latents and states
    rng_seed : :obj:`int`, optional
        random number generator seed to control sampling

    Returns
    -------
    :obj:`dict`
        - 'model' (:obj:`ssm.HMM` object)
        - 'latents' (:obj:`dict`): latents from train, val and test trials
        - 'states' (:obj:`dict`): states from train, val and test trials
        - 'trial_idxs' (:obj:`dict`): trial indices from train, val and test trials
        - 'latents_gen' (:obj:`list`)
        - 'states_gen' (:obj:`list`)

    """
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # get version/model
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'],
                                         measure='val_loss',
                                         best_def='max')[0]
    else:
        _, version = experiment_exists(hparams, which_version=True)
    if version is None:
        raise FileNotFoundError(
            'Could not find the specified model version in %s' %
            hparams['expt_dir'])

    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # load latents/labels
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        all_latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            all_latents = pickle.load(f)

    # collect inferred latents/states
    trial_idxs = {}
    latents = {}
    states = {}
    for data_type in dtypes:
        trial_idxs[data_type] = all_latents['trials'][data_type]
        latents[data_type] = [
            all_latents['latents'][i_trial]
            for i_trial in trial_idxs[data_type]
        ]
        states[data_type] = [
            hmm.most_likely_states(x) for x in latents[data_type]
        ]

    # collect sampled latents/states
    if return_samples > 0:
        states_gen = []
        np.random.seed(rng_seed)
        if cond_sampling:
            # sample latents conditioned on the inferred most likely state sequence;
            # the first observed latent of each segment seeds the autoregressive draw
            n_latents = latents[dtype][0].shape[1]
            latents_gen = [
                np.zeros((len(state_seg), n_latents))
                for state_seg in states[dtype]
            ]
            for i_seg, state_seg in enumerate(states[dtype]):
                for i_t in range(len(state_seg)):
                    if i_t >= 1:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents_gen[i_seg][:i_t],
                            input=np.zeros(0))
                    else:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents[dtype][i_seg][0].reshape((1, n_latents)),
                            input=np.zeros(0))
        else:
            # draw unconditioned samples, discarding an initial burn-in period
            latents_gen = []
            offset = 200  # number of burn-in samples to discard from each draw
            for i in range(return_samples):
                these_states_gen, these_latents_gen = hmm.sample(
                    latents[dtype][0].shape[0] + offset)
                latents_gen.append(these_latents_gen[offset:])
                states_gen.append(these_states_gen[offset:])
    else:
        latents_gen = []
        states_gen = []

    return_dict = {
        'model': hmm,
        'latents': latents,
        'states': states,
        'trial_idxs': trial_idxs,
        'latents_gen': latents_gen,
        'states_gen': states_gen,
    }
    return return_dict
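
A sketch of how the returned dictionary might be consumed, using the keys documented above (hparams is again a placeholder):

model_output = get_model_latents_states(
    hparams, version='best', sess_idx=0, return_samples=10, cond_sampling=False)
states_test = model_output['states']['test']  # inferred discrete states per test trial
latents_gen = model_output['latents_gen']     # ten unconditioned sampled latent trajectories
states_gen = model_output['states_gen']       # matching sampled state sequences
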
Code example #5
# Module-level imports assumed from the test file:
import os

import pytest

from behavenet.data import utils

def test_get_transforms_paths():

    hparams = {
        'data_dir': 'ddir',
        'results_dir': 'rdir',
        'lab': 'lab',
        'expt': 'expt',
        'animal': 'animal',
        'session': 'session'
    }
    session_dir = os.path.join(hparams['data_dir'], hparams['lab'],
                               hparams['expt'], hparams['animal'],
                               hparams['session'])
    hdf5_path = os.path.join(session_dir, 'data.hdf5')
    sess_id_str = '%s_%s_%s_%s_' % (hparams['lab'], hparams['expt'],
                                    hparams['animal'], hparams['session'])

    # ------------------------
    # neural data
    # ------------------------
    # spikes, no thresholding
    hparams['neural_type'] = 'spikes'
    hparams['neural_thresh'] = 0
    transform, path = utils.get_transforms_paths('neural',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hdf5_path
    assert transform is None

    # spikes, thresholding
    hparams['neural_type'] = 'spikes'
    hparams['neural_thresh'] = 1
    hparams['neural_bin_size'] = 1
    transform, path = utils.get_transforms_paths('neural',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hdf5_path
    assert 'Threshold' in repr(transform)

    # calcium, no zscoring
    hparams['neural_type'] = 'ca'
    hparams['model_type'] = 'ae-neural'
    transform, path = utils.get_transforms_paths('neural',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hdf5_path
    assert transform is None

    # calcium, zscoring
    hparams['neural_type'] = 'ca'
    hparams['model_type'] = 'neural-ae'
    transform, path = utils.get_transforms_paths('neural',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hdf5_path
    assert 'ZScore' in repr(transform)

    # raise exception for incorrect neural type
    hparams['neural_type'] = 'wf'
    with pytest.raises(ValueError):
        utils.get_transforms_paths('neural',
                                   hparams,
                                   sess_id=None,
                                   check_splits=False)

    # TODO: test subsampling methods

    # ------------------------
    # ae latents
    # ------------------------
    hparams['session_dir'] = session_dir
    hparams['ae_model_class'] = 'ae'
    hparams['ae_model_type'] = 'conv'
    hparams['n_ae_latents'] = 8
    hparams['ae_experiment_name'] = 'tt_expt_ae'
    hparams['ae_version'] = 0

    ae_path = os.path.join(hparams['data_dir'], hparams['lab'],
                           hparams['expt'], hparams['animal'],
                           hparams['session'], hparams['ae_model_class'],
                           hparams['ae_model_type'],
                           '%02i_latents' % hparams['n_ae_latents'],
                           hparams['ae_experiment_name'])

    # user-defined latent path
    hparams['ae_latents_file'] = 'path/to/latents'
    transform, path = utils.get_transforms_paths('ae_latents',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hparams['ae_latents_file']
    assert transform is None
    hparams.pop('ae_latents_file')

    # build pathname from hparams
    transform, path = utils.get_transforms_paths('ae_latents',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == os.path.join(ae_path, 'version_%i' % hparams['ae_version'],
                                '%slatents.pkl' % sess_id_str)
    assert transform is None

    # get correct transform
    transform, path = utils.get_transforms_paths('ae_latents_me',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == os.path.join(ae_path, 'version_%i' % hparams['ae_version'],
                                '%slatents.pkl' % sess_id_str)
    assert 'MotionEnergy' in repr(transform)

    # TODO: use get_best_model_version()

    # ------------------------
    # arhmm states
    # ------------------------
    hparams['n_ae_latents'] = 8
    hparams['n_arhmm_states'] = 2
    hparams['transitions'] = 'stationary'
    hparams['noise_type'] = 'gaussian'
    hparams['arhmm_experiment_name'] = 'tt_expt_arhmm'
    hparams['arhmm_version'] = 1

    arhmm_path = os.path.join(hparams['data_dir'], hparams['lab'],
                              hparams['expt'], hparams['animal'],
                              hparams['session'], 'arhmm',
                              '%02i_latents' % hparams['n_ae_latents'],
                              '%02i_states' % hparams['n_arhmm_states'],
                              hparams['transitions'], hparams['noise_type'],
                              hparams['arhmm_experiment_name'])

    # user-defined state path
    hparams['arhmm_states_file'] = 'path/to/states'
    transform, path = utils.get_transforms_paths('arhmm_states',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hparams['arhmm_states_file']
    assert transform is None
    hparams.pop('arhmm_states_file')

    # build path name from hparams
    transform, path = utils.get_transforms_paths('arhmm_states',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == os.path.join(arhmm_path,
                                'version_%i' % hparams['arhmm_version'],
                                '%sstates.pkl' % sess_id_str)
    assert transform is None

    # include shuffle transform
    hparams['shuffle_rng_seed'] = 0
    transform, path = utils.get_transforms_paths('arhmm_states',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == os.path.join(arhmm_path,
                                'version_%i' % hparams['arhmm_version'],
                                '%sstates.pkl' % sess_id_str)
    assert 'BlockShuffle' in repr(transform)

    # TODO: use get_best_model_version()

    # ------------------------
    # neural ae predictions
    # ------------------------
    hparams['n_ae_latents'] = 8
    hparams['neural_ae_model_type'] = 'linear'
    hparams['neural_ae_experiment_name'] = 'tt_expt_ae_decoder'
    hparams['neural_ae_version'] = 2

    ae_pred_path = os.path.join(hparams['data_dir'], hparams['lab'],
                                hparams['expt'], hparams['animal'],
                                hparams['session'], 'neural-ae',
                                '%02i_latents' % hparams['n_ae_latents'],
                                hparams['neural_ae_model_type'], 'all',
                                hparams['neural_ae_experiment_name'])

    # user-defined predictions path
    hparams['ae_predictions_file'] = 'path/to/predictions'
    transform, path = utils.get_transforms_paths('neural_ae_predictions',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hparams['ae_predictions_file']
    assert transform is None
    hparams.pop('ae_predictions_file')

    # build pathname from hparams
    transform, path = utils.get_transforms_paths('neural_ae_predictions',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == os.path.join(ae_pred_path,
                                'version_%i' % hparams['neural_ae_version'],
                                '%spredictions.pkl' % sess_id_str)
    assert transform is None

    # TODO: use get_best_model_version()

    # ------------------------
    # neural arhmm predictions
    # ------------------------
    hparams['n_ae_latents'] = 8
    hparams['n_arhmm_states'] = 10
    hparams['transitions'] = 'stationary'
    hparams['noise_type'] = 'studentst'
    hparams['neural_arhmm_model_type'] = 'linear'
    hparams['neural_arhmm_experiment_name'] = 'tt_expt_ae_decoder'
    hparams['neural_arhmm_version'] = 3

    arhmm_pred_path = os.path.join(hparams['data_dir'], hparams['lab'],
                                   hparams['expt'], hparams['animal'],
                                   hparams['session'], 'neural-arhmm',
                                   '%02i_latents' % hparams['n_ae_latents'],
                                   '%02i_states' % hparams['n_arhmm_states'],
                                   hparams['transitions'],
                                   hparams['neural_arhmm_model_type'], 'all',
                                   hparams['neural_arhmm_experiment_name'])

    # user-defined predictions path
    hparams['arhmm_predictions_file'] = 'path/to/predictions'
    transform, path = utils.get_transforms_paths('neural_arhmm_predictions',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == hparams['arhmm_predictions_file']
    assert transform is None
    hparams.pop('arhmm_predictions_file')

    # build pathname from hparams
    transform, path = utils.get_transforms_paths('neural_arhmm_predictions',
                                                 hparams,
                                                 sess_id=None,
                                                 check_splits=False)
    assert path == os.path.join(arhmm_pred_path,
                                'version_%i' % hparams['neural_arhmm_version'],
                                '%spredictions.pkl' % sess_id_str)
    assert transform is None

    # TODO: use get_best_model_version()

    # ------------------------
    # other
    # ------------------------
    with pytest.raises(ValueError):
        utils.get_transforms_paths('invalid',
                                   hparams,
                                   sess_id=None,
                                   check_splits=False)