Ejemplo n.º 1
0
def build_data_generator(hparams, sess_ids, export_csv=True):
    """Construct a :obj:`ConcatSessionsGenerator` from the entries of an hparams dict.

    Session paths are printed, the generator inputs are obtained from
    :obj:`get_data_generator_inputs`, and (optionally) the session info is exported to a csv file
    located in the ``version_%i`` subdirectory of ``hparams['expt_dir']``.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain information specifying data inputs to model; the keys read directly here
        are 'save_dir', 'data_dir', 'device', 'as_numpy', 'batch_load', 'rng_seed_data',
        'train_frac', optionally 'trial_splits', and (when exporting) 'expt_dir' and 'version'
    sess_ids : :obj:`list` of :obj:`dict`
        each entry is a session dict with keys 'lab', 'expt', 'animal', 'session'
    export_csv : :obj:`bool`, optional
        export csv file containing session info (useful when fitting multi-sessions)

    Returns
    -------
    :obj:`ConcatSessionsGenerator` object
        data generator

    """
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.data_generator import ConcatSessionsGenerator

    print('using data from following sessions:')
    for sess in sess_ids:
        sess_path = os.path.join(
            hparams['save_dir'], sess['lab'], sess['expt'], sess['animal'], sess['session'])
        print('%s' % sess_path)

    hparams, signals, transforms, paths = get_data_generator_inputs(hparams, sess_ids)

    # optional trial splits are given as a string of the form 'train;val;test;gap'
    split_str = hparams.get('trial_splits', None)
    if split_str is None:
        trial_splits = None
    else:
        counts = [int(field) for field in split_str.split(';')]
        split_keys = ('train_tr', 'val_tr', 'test_tr', 'gap_tr')
        # index explicitly so that a string with too few fields still raises an IndexError
        trial_splits = {key: counts[pos] for pos, key in enumerate(split_keys)}

    print('constructing data generator...', end='')
    generator_kwargs = dict(
        signals_list=signals,
        transforms_list=transforms,
        paths_list=paths,
        device=hparams['device'],
        as_numpy=hparams['as_numpy'],
        batch_load=hparams['batch_load'],
        rng_seed=hparams['rng_seed_data'],
        trial_splits=trial_splits,
        train_frac=hparams['train_frac'])
    data_generator = ConcatSessionsGenerator(hparams['data_dir'], sess_ids, **generator_kwargs)

    # csv order will reflect dataset order in data generator
    if export_csv:
        version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % hparams['version'])
        export_session_info_to_csv(version_dir, sess_ids)

    print('done')
    print(data_generator)
    return data_generator
Ejemplo n.º 2
0
def get_best_model_and_data(hparams, Model=None, load_data=True, version='best', data_kwargs=None):
    """Load a fitted model (and optionally its data) from a test-tube version directory.

    The version directory is selected according to `version`, the pickled hparams stored in
    ``meta_tags.pkl`` are loaded and patched with machine-specific entries from `hparams`, and the
    model weights are restored from ``best_val_model.pt`` (falling back to
    ``best_val_model.ckpt``). The returned model is put in eval mode.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify both a model and the associated data; the
        'session_dir' entry is overwritten by this function
    Model : :obj:`behavenet.models` object, optional
        model type; if :obj:`NoneType` the class is chosen from ``hparams['model_class']``
    load_data : :obj:`bool`, optional
        if `False` then data generator is not returned
    version : :obj:`str` or :obj:`int` or :obj:`NoneType`, optional
        'best' to load best model; :obj:`NoneType` to use the version reported by
        :obj:`experiment_exists` for `hparams`; a string starting with 'v' is used as the
        directory name as-is; anything else is formatted as 'version_{}'
    data_kwargs : :obj:`dict`, optional
        additional kwargs for data generator

    Returns
    -------
    :obj:`tuple`
        - model (:obj:`behavenet.models` object)
        - data generator (:obj:`ConcatSessionsGenerator` object or :obj:`NoneType`)

    Raises
    ------
    NotImplementedError
        if `Model` is :obj:`NoneType` and ``hparams['model_class']`` is 'arhmm' or unrecognized

    """

    import torch
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.data_generator import ConcatSessionsGenerator

    # locate session and experiment directories
    hparams['session_dir'], sess_ids = get_session_dir(
        hparams, session_source=hparams.get('all_source', 'save'))
    expt_dir = get_expt_dir(hparams)

    # work out the name of the version directory
    if version == 'best':
        best_version = 'version_{}'.format(get_best_model_version(expt_dir)[0])
    elif version is None:
        # try to match hparams
        _, matched_version = experiment_exists(hparams, which_version=True)
        best_version = 'version_{}'.format(matched_version)
    elif isinstance(version, str) and version[0] == 'v':
        # assume we got a string of the form 'version_{%i}'
        best_version = version
    else:
        best_version = 'version_{}'.format(version)

    version_dir = os.path.join(expt_dir, best_version)
    arch_file = os.path.join(version_dir, 'meta_tags.pkl')
    model_file = os.path.join(version_dir, 'best_val_model.pt')
    # fall back to the .ckpt name when neither the .pt file nor a .pt.meta file is present
    if not (os.path.exists(model_file) or os.path.exists(model_file + '.meta')):
        model_file = os.path.join(version_dir, 'best_val_model.ckpt')
    print('Loading model defined in %s' % arch_file)

    # NOTE(review): pickle is only safe on trusted files
    with open(arch_file, 'rb') as f:
        hparams_new = pickle.load(f)

    # update paths if performing analysis on a different machine
    hparams_new['data_dir'] = hparams['data_dir']
    hparams_new['session_dir'] = hparams['session_dir']
    hparams_new['expt_dir'] = expt_dir
    for mask_key in ('use_output_mask', 'use_label_mask'):
        hparams_new[mask_key] = hparams.get(mask_key, False)
    hparams_new['device'] = hparams.get('device', 'cpu')

    # generator inputs are computed even when no data is loaded, since this call also returns the
    # hparams dict that is handed to the model constructor below
    hparams_new, signals, transforms, paths = get_data_generator_inputs(hparams_new, sess_ids)
    data_generator = None
    if load_data:
        # sometimes we want a single data_generator for multiple models
        extra_kwargs = {} if data_kwargs is None else data_kwargs
        data_generator = ConcatSessionsGenerator(
            hparams_new['data_dir'],
            sess_ids,
            signals_list=signals,
            transforms_list=transforms,
            paths_list=paths,
            device=hparams_new['device'],
            as_numpy=hparams_new['as_numpy'],
            batch_load=hparams_new['batch_load'],
            rng_seed=hparams_new['rng_seed_data'],
            train_frac=hparams_new['train_frac'],
            **extra_kwargs)

    # choose the model class from the user-supplied hparams if none was given
    if Model is None:
        model_class = hparams['model_class']
        decoder_classes = (
            'neural-ae', 'neural-ae-me', 'neural-arhmm', 'neural-labels',
            'ae-neural', 'arhmm-neural', 'labels-neural')
        if model_class == 'ae':
            from behavenet.models import AE as Model
        elif model_class == 'vae':
            from behavenet.models import VAE as Model
        elif model_class == 'cond-ae':
            from behavenet.models import ConditionalAE as Model
        elif model_class == 'cond-vae':
            from behavenet.models import ConditionalVAE as Model
        elif model_class == 'cond-ae-msp':
            from behavenet.models import AEMSP as Model
        elif model_class == 'beta-tcvae':
            from behavenet.models import BetaTCVAE as Model
        elif model_class == 'ps-vae':
            from behavenet.models import PSVAE as Model
        elif model_class == 'msps-vae':
            from behavenet.models import MSPSVAE as Model
        elif model_class == 'labels-images':
            from behavenet.models import ConvDecoder as Model
        elif model_class in decoder_classes:
            from behavenet.models import Decoder as Model
        elif model_class == 'arhmm':
            raise NotImplementedError('Cannot use get_best_model_and_data() for ssm models')
        else:
            raise NotImplementedError

    # instantiate, restore weights, and move to the requested device
    model = Model(hparams_new)
    model.version = int(best_version.split('_')[1])
    # map_location keeps tensors on the storage they were deserialized to (cpu)
    model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
    model.to(hparams_new['device'])
    model.eval()

    return model, data_generator
Ejemplo n.º 3
0
def make_syllable_movies_wrapper(
        hparams, save_file, sess_idx=0, dtype='test', max_frames=400, frame_rate=10,
        min_threshold=0, n_buffer=5, n_pre_frames=3, n_rows=None, single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    High-level entry point: the images for one session are loaded through a data generator, the
    latents (or labels) and the pickled arhmm described by `hparams` are loaded from disk, the
    most likely discrete states are inferred for the requested trials, and the resulting state
    runs are handed to :obj:`make_syllable_movies`.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm; the 'session_dir', 'expt_dir' and
        'load_videos' entries are overwritten by this function
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make video with; 'train' | 'val' | 'test'
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    min_threshold : :obj:`int`, optional
        minimum number of frames in a syllable run to be considered for movie
    n_buffer : :obj:`int`
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`
        choose only a single state for movie

    Raises
    ------
    ValueError
        if there are no trials of type `dtype`

    """
    from behavenet.fitting.utils import get_session_dir
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import experiment_exists
    from behavenet.data.utils import get_transforms_paths
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.data_generator import ConcatSessionsGenerator

    # ---- images ----
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)
    hparams['load_videos'] = True
    hparams, signals, transforms, paths = get_data_generator_inputs(hparams, sess_ids)
    generator = ConcatSessionsGenerator(
        hparams['data_dir'], sess_ids,
        signals_list=[signals[sess_idx]],
        transforms_list=[transforms[sess_idx]],
        paths_list=[paths[sess_idx]],
        device='cpu', as_numpy=True, batch_load=False, rng_seed=hparams['rng_seed_data'])
    ims_orig = generator.datasets[sess_idx].data['images']
    del generator  # free up memory

    # ---- test-tube version of the arhmm ----
    _, version = experiment_exists(hparams, which_version=True)
    print('producing syllable videos for arhmm %s' % version)

    # ---- latents/labels ----
    if 'labels' in hparams['model_class']:
        from behavenet.data.utils import load_labels_like_latents
        latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            latents = pickle.load(f)
    trial_idxs = latents['trials'][dtype]

    # ---- model (stored as a pickle despite the .pt extension) ----
    model_file = os.path.join(
        hparams['expt_dir'], 'version_%i' % version, 'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # ---- discrete states, one entry per trial of the requested type ----
    states = []
    for trial in trial_idxs:
        states.append(hmm.most_likely_states(latents['latents'][trial]))
    if not states:
        raise ValueError('No latents for dtype=%s' % dtype)

    # find runs of discrete states; state indices is a list, each entry of which is a np array with
    # shape (n_state_instances, 3), where the 3 values are:
    # chunk_idx, chunk_start_idx, chunk_end_idx
    # chunk_idx is in [0, n_chunks], and indexes trial_idxs
    state_indices = get_discrete_chunks(states, include_edges=True)

    # keep, for every state, only the instances longer than the minimum length threshold; states
    # without any instance keep an empty list
    over_threshold_instances = []
    for instances in state_indices:
        if instances.shape[0] > 0:
            # column 2 minus column 1 -> run length, shape (n_state_instances, 1)
            run_lengths = np.diff(instances[:, 1:3], axis=1)
            keep = (run_lengths > min_threshold)[:, 0]
            kept_instances = instances[keep]
            np.random.shuffle(kept_instances)  # shuffle instances
            over_threshold_instances.append(kept_instances)
        else:
            over_threshold_instances.append([])

    make_syllable_movies(
        ims_orig=ims_orig,
        state_list=over_threshold_instances,
        trial_idxs=trial_idxs,
        save_file=save_file,
        max_frames=max_frames,
        frame_rate=frame_rate,
        n_buffer=n_buffer,
        n_pre_frames=n_pre_frames,
        n_rows=n_rows,
        single_syllable=single_syllable)
Ejemplo n.º 4
0
def get_best_model_and_data(hparams, Model, load_data=True, version='best', data_kwargs=None):
    """Load a fitted model (and optionally its data) from a test-tube version directory.

    The pickled hparams stored in ``meta_tags.pkl`` of the selected version are loaded, patched
    with machine-specific entries from `hparams`, and used to build `Model`; weights are restored
    from ``best_val_model.pt`` (falling back to ``best_val_model.ckpt``). The model is always
    placed on the cpu and put in eval mode.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify both a model and the associated data; the
        'session_dir' entry is overwritten by this function
    Model : :obj:`behavenet.models` object
        model type
    load_data : :obj:`bool`, optional
        if `False` then data generator is not returned
    version : :obj:`str` or :obj:`int`, optional
        can be 'best' to load best model; a string starting with 'v' is used as the directory
        name as-is; anything else is formatted as 'version_{}'
    data_kwargs : :obj:`dict`, optional
        additional kwargs for data generator

    Returns
    -------
    :obj:`tuple`
        - model (:obj:`behavenet.models` object)
        - data generator (:obj:`ConcatSessionsGenerator` object or :obj:`NoneType`)

    """

    import torch
    from behavenet.data.data_generator import ConcatSessionsGenerator

    # locate session and experiment directories
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    expt_dir = get_expt_dir(hparams)

    # work out the name of the version directory
    if version == 'best':
        best_version = 'version_{}'.format(get_best_model_version(expt_dir)[0])
    elif isinstance(version, str) and version[0] == 'v':
        # assume we got a string of the form 'version_{%i}'
        best_version = version
    else:
        best_version = 'version_{}'.format(version)

    version_dir = os.path.join(expt_dir, best_version)
    arch_file = os.path.join(version_dir, 'meta_tags.pkl')
    model_file = os.path.join(version_dir, 'best_val_model.pt')
    # fall back to the .ckpt name when neither the .pt file nor a .pt.meta file is present
    if not (os.path.exists(model_file) or os.path.exists(model_file + '.meta')):
        model_file = os.path.join(version_dir, 'best_val_model.ckpt')
    print('Loading model defined in %s' % arch_file)

    # NOTE(review): pickle is only safe on trusted files
    with open(arch_file, 'rb') as f:
        hparams_new = pickle.load(f)

    # update paths if performing analysis on a different machine
    hparams_new['data_dir'] = hparams['data_dir']
    hparams_new['session_dir'] = hparams['session_dir']
    hparams_new['expt_dir'] = expt_dir
    hparams_new['use_output_mask'] = hparams.get('use_output_mask', False)
    hparams_new['device'] = 'cpu'

    # generator inputs are computed even when no data is loaded, since this call also returns the
    # hparams dict that is handed to the model constructor below
    hparams_new, signals, transforms, paths = get_data_generator_inputs(hparams_new, sess_ids)
    data_generator = None
    if load_data:
        # sometimes we want a single data_generator for multiple models
        extra_kwargs = {} if data_kwargs is None else data_kwargs
        data_generator = ConcatSessionsGenerator(
            hparams_new['data_dir'],
            sess_ids,
            signals_list=signals,
            transforms_list=transforms,
            paths_list=paths,
            device=hparams_new['device'],
            as_numpy=hparams_new['as_numpy'],
            batch_load=hparams_new['batch_load'],
            rng_seed=hparams_new['rng_seed_data'],
            **extra_kwargs)

    # instantiate, restore weights, and move to the device stored in hparams_new
    model = Model(hparams_new)
    model.version = int(best_version.split('_')[1])
    # map_location keeps tensors on the storage they were deserialized to (cpu)
    model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
    model.to(hparams_new['device'])
    model.eval()

    return model, data_generator
Ejemplo n.º 5
0
def test_get_data_generator_inputs():
    """Check the outputs of ``utils.get_data_generator_inputs`` for every model class.

    A single hparams dict is mutated as the test proceeds, so later sections rely on keys set by
    earlier ones (e.g. 'n_ae_latents', 'session_dir'); the order of the sections matters.
    """

    hparams = dict(
        data_dir='ddir',
        results_dir='rdir',
        lab='lab0',
        expt='expt0',
        animal='animal0',
        session='session0')
    id_keys = ('lab', 'expt', 'animal', 'session')
    session_dir = os.path.join(hparams['data_dir'], *[hparams[key] for key in id_keys])
    hdf5_path = os.path.join(session_dir, 'data.hdf5')
    sess_ids = [{key: hparams[key] for key in id_keys}]

    def run():
        # call the function under test with the current state of the shared hparams dict
        return utils.get_data_generator_inputs(hparams, sess_ids, check_splits=False)

    # -----------------
    # ae / vae / beta-tcvae
    # -----------------
    for image_model in ('ae', 'vae', 'beta-tcvae'):
        hparams['model_class'] = image_model
        hp_out, sigs, tfms, locs = run()
        assert sigs[0] == ['images']
        assert tfms[0] == [None]
        assert locs[0] == [hdf5_path]

        hparams['model_class'] = image_model
        hparams['use_output_mask'] = True
        hp_out, sigs, tfms, locs = run()
        assert sigs[0] == ['images', 'masks']
        assert tfms[0] == [None] * 2
        assert locs[0] == [hdf5_path] * 2
        hparams['use_output_mask'] = False

    # -----------------
    # ps-vae
    # -----------------
    hparams['model_class'] = 'ps-vae'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2

    hparams['model_class'] = 'ps-vae'
    hparams['use_output_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels', 'masks']
    assert tfms[0] == [None] * 3
    assert locs[0] == [hdf5_path] * 3
    hparams['use_output_mask'] = False

    hparams['model_class'] = 'ps-vae'
    hparams['use_label_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels', 'labels_masks']
    assert tfms[0] == [None] * 3
    assert locs[0] == [hdf5_path] * 3
    hparams['use_label_mask'] = False

    # -----------------
    # cond-vae
    # -----------------
    hparams['model_class'] = 'cond-vae'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2

    hparams['model_class'] = 'cond-vae'
    hparams['use_output_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels', 'masks']
    assert tfms[0] == [None] * 3
    assert locs[0] == [hdf5_path] * 3
    hparams['use_output_mask'] = False

    # -----------------
    # cond-ae
    # -----------------
    hparams['model_class'] = 'cond-ae'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2

    hparams['model_class'] = 'cond-ae'
    hparams['use_output_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels', 'masks']
    assert tfms[0] == [None] * 3
    assert locs[0] == [hdf5_path] * 3
    hparams['use_output_mask'] = False

    hparams.update(model_class='cond-ae', conditional_encoder=True, y_pixels=2, x_pixels=2)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels', 'labels_sc']
    assert tfms[0][0] is None
    assert tfms[0][1] is None
    assert 'MakeOneHot2D' in repr(tfms[0][2])
    assert locs[0] == [hdf5_path] * 3
    hparams['conditional_encoder'] = False

    # -----------------
    # cond-ae-msp
    # -----------------
    hparams['model_class'] = 'cond-ae-msp'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2

    hparams['model_class'] = 'cond-ae-msp'
    hparams['use_label_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels', 'labels_masks']
    assert tfms[0] == [None] * 3
    assert locs[0] == [hdf5_path] * 3
    hparams['use_label_mask'] = False

    # -----------------
    # ae_latents
    # -----------------
    hparams.update(
        model_class='ae_latents',
        session_dir=session_dir,
        ae_model_class='ae',
        ae_model_type='conv',
        n_ae_latents=8,
        ae_experiment_name='tt_expt_ae',
        ae_version=0)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['ae_latents']
    # transforms and paths tested by test_get_transforms_paths

    # -----------------
    # neural-ae
    # -----------------
    hparams.update(
        model_class='neural-ae',
        model_type='linear',
        session_dir=session_dir,
        neural_type='spikes',
        neural_thresh=0)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['neural', 'ae_latents']
    assert hp_out['input_signal'] == 'neural'
    assert hp_out['output_signal'] == 'ae_latents'
    assert hp_out['output_size'] == hparams['n_ae_latents']
    assert hp_out['noise_dist'] == 'gaussian'

    hparams['model_type'] = 'linear-mv'
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian-full'

    # -----------------
    # neural-ae-me
    # -----------------
    hparams.update(
        model_class='neural-ae-me',
        model_type='linear',
        session_dir=session_dir,
        neural_type='spikes',
        neural_thresh=0)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['neural', 'ae_latents']
    assert tfms[0][0] is None
    assert 'MotionEnergy' in repr(tfms[0][1])
    assert hp_out['input_signal'] == 'neural'
    assert hp_out['output_signal'] == 'ae_latents'
    assert hp_out['output_size'] == hparams['n_ae_latents']
    assert hp_out['noise_dist'] == 'gaussian'

    hparams['model_type'] = 'linear-mv'
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian-full'

    # -----------------
    # ae-neural
    # -----------------
    hparams.update(
        model_class='ae-neural',
        model_type='linear',
        session_dir=session_dir,
        neural_type='spikes',
        neural_thresh=0)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['neural', 'ae_latents']
    assert hp_out['input_signal'] == 'ae_latents'
    assert hp_out['output_signal'] == 'neural'
    assert hp_out['output_size'] is None
    assert hp_out['noise_dist'] == 'poisson'

    hparams.update(model_type='linear', neural_type='ca')
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian'

    hparams.update(model_type='linear-mv', neural_type='ca')
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian-full'

    # -----------------
    # neural-labels
    # -----------------
    hparams.update(
        model_class='neural-labels',
        model_type='linear',
        n_labels=4,
        session_dir=session_dir,
        neural_type='spikes',
        neural_thresh=0)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['neural', 'labels']
    assert hp_out['input_signal'] == 'neural'
    assert hp_out['output_signal'] == 'labels'
    assert hp_out['output_size'] == hparams['n_labels']
    assert hp_out['noise_dist'] == 'gaussian'

    hparams['model_type'] = 'linear-mv'
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian-full'

    # -----------------
    # labels-neural
    # -----------------
    hparams.update(
        model_class='labels-neural',
        model_type='linear',
        n_labels=4,
        session_dir=session_dir,
        neural_type='spikes',
        neural_thresh=0)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['neural', 'labels']
    assert hp_out['input_signal'] == 'labels'
    assert hp_out['output_signal'] == 'neural'
    assert hp_out['output_size'] is None
    assert hp_out['noise_dist'] == 'poisson'

    hparams.update(model_type='linear', neural_type='ca')
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian'

    hparams.update(model_type='linear-mv', neural_type='ca')
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian-full'

    # -----------------
    # arhmm
    # -----------------
    hparams.update(
        model_class='arhmm',
        session_dir=session_dir,
        ae_model_type='conv',
        n_ae_latents=8,
        ae_experiment_name='tt_expt_ae',
        ae_version=0)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['ae_latents']
    # transforms and paths tested by test_get_transforms_paths

    hparams['load_videos'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['ae_latents', 'images']
    hparams['load_videos'] = False

    hparams['use_output_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['ae_latents', 'masks']
    hparams['use_output_mask'] = False

    # -----------------
    # arhmm-labels
    # -----------------
    hparams['model_class'] = 'arhmm-labels'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['labels']
    assert tfms[0] == [None]
    assert locs[0] == [hdf5_path]

    hparams['load_videos'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['labels', 'images']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2
    hparams['load_videos'] = False

    hparams['use_output_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['labels', 'masks']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2
    hparams['use_output_mask'] = False

    # -----------------
    # neural-arhmm
    # -----------------
    hparams.update(
        model_class='neural-arhmm',
        model_type='linear',
        session_dir=session_dir,
        neural_type='spikes',
        neural_thresh=0,
        n_arhmm_states=2,
        transitions='stationary',
        noise_type='gaussian',
        arhmm_experiment_name='tt_expt_arhmm',
        arhmm_version=1)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['neural', 'arhmm_states']
    assert hp_out['input_signal'] == 'neural'
    assert hp_out['output_signal'] == 'arhmm_states'
    assert hp_out['output_size'] == hparams['n_arhmm_states']
    assert hp_out['noise_dist'] == 'categorical'

    # -----------------
    # arhmm-neural
    # -----------------
    hparams.update(
        model_class='arhmm-neural',
        model_type='linear',
        session_dir=session_dir,
        neural_type='spikes',
        neural_thresh=0,
        n_arhmm_states=2,
        transitions='stationary',
        noise_type='gaussian',
        arhmm_experiment_name='tt_expt_arhmm',
        arhmm_version=1)
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['neural', 'arhmm_states']
    assert hp_out['input_signal'] == 'arhmm_states'
    assert hp_out['output_signal'] == 'neural'
    assert hp_out['output_size'] is None
    assert hp_out['noise_dist'] == 'poisson'

    hparams.update(model_type='linear', neural_type='ca')
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian'

    hparams.update(model_type='linear-mv', neural_type='ca')
    hp_out, sigs, tfms, locs = run()
    assert hp_out['noise_dist'] == 'gaussian-full'

    # -----------------
    # bayesian-decoding
    # -----------------
    hparams.update(
        model_class='bayesian-decoding',
        neural_ae_experiment_name='tt_expt_ae_decoder',
        neural_ae_model_type='linear',
        neural_ae_version=0,
        neural_arhmm_experiment_name='tt_expt_arhmm_decoder',
        neural_arhmm_model_type='linear',
        neural_arhmm_version=0)
    decoding_signals = ['ae_latents', 'ae_predictions', 'arhmm_predictions', 'arhmm_states']
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == decoding_signals

    hparams['load_videos'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == decoding_signals + ['images']
    hparams['load_videos'] = False

    hparams['use_output_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == decoding_signals + ['masks']
    hparams['use_output_mask'] = False

    # -----------------
    # labels-images
    # -----------------
    hparams['model_class'] = 'labels-images'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2
    assert hp_out['input_signal'] == 'labels'
    assert hp_out['output_signal'] == 'images'

    hparams['use_output_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['images', 'labels', 'masks']
    hparams['use_output_mask'] = False

    # -----------------
    # labels
    # -----------------
    hparams['model_class'] = 'labels'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['labels']
    assert tfms[0] == [None]
    assert locs[0] == [hdf5_path]

    hparams['use_label_mask'] = True
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['labels', 'labels_masks']
    assert tfms[0] == [None] * 2
    assert locs[0] == [hdf5_path] * 2
    hparams['use_label_mask'] = False

    # -----------------
    # labels_masks
    # -----------------
    hparams['model_class'] = 'labels_masks'
    hp_out, sigs, tfms, locs = run()
    assert sigs[0] == ['labels_masks']
    assert tfms[0] == [None]
    assert locs[0] == [hdf5_path]

    # -----------------
    # other
    # -----------------
    # an unrecognized model class must be rejected
    hparams['model_class'] = 'test'
    with pytest.raises(ValueError):
        run()