Example #1
def build_data_generator(hparams, sess_ids, export_csv=True):
    """Helper function to build data generator from hparams dict.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain information specifying data inputs to the model
    sess_ids : :obj:`list` of :obj:`dict`
        each entry is a session dict with keys 'lab', 'expt', 'animal', 'session'
    export_csv : :obj:`bool`, optional
        export a csv file containing session info (useful when fitting multiple sessions)

    Returns
    -------
    :obj:`ConcatSessionsGenerator` object
        data generator

    """
    from behavenet.data.data_generator import ConcatSessionsGenerator
    print('using data from following sessions:')
    for ids in sess_ids:
        print('%s' % os.path.join(hparams['save_dir'], ids['lab'], ids['expt'],
                                  ids['animal'], ids['session']))
    hparams, signals, transforms, paths = get_data_generator_inputs(
        hparams, sess_ids)
    if hparams.get('trial_splits', None) is not None:
        # assumes string of form 'train;val;test;gap'
        trs = [int(tr) for tr in hparams['trial_splits'].split(';')]
        trial_splits = {
            'train_tr': trs[0],
            'val_tr': trs[1],
            'test_tr': trs[2],
            'gap_tr': trs[3]
        }
    else:
        trial_splits = None
    print('constructing data generator...', end='')
    data_generator = ConcatSessionsGenerator(hparams['data_dir'],
                                             sess_ids,
                                             signals_list=signals,
                                             transforms_list=transforms,
                                             paths_list=paths,
                                             device=hparams['device'],
                                             as_numpy=hparams['as_numpy'],
                                             batch_load=hparams['batch_load'],
                                             rng_seed=hparams['rng_seed_data'],
                                             trial_splits=trial_splits,
                                             train_frac=hparams['train_frac'])
    # csv order will reflect dataset order in data generator
    if export_csv:
        export_session_info_to_csv(
            os.path.join(hparams['expt_dir'],
                         str('version_%i' % hparams['version'])), sess_ids)
    print('done')
    print(data_generator)
    return data_generator
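A minimal usage sketch (not part of the original source): the paths and session ids below are hypothetical placeholders, and get_data_generator_inputs will require additional model-specific keys (e.g. 'model_class') that are omitted here.

hparams = {
    'data_dir': '/path/to/data',     # hypothetical path
    'save_dir': '/path/to/results',  # hypothetical path
    'device': 'cpu',
    'as_numpy': False,
    'batch_load': True,
    'rng_seed_data': 0,
    'train_frac': 1.0,
    'trial_splits': '8;1;1;0',       # 'train;val;test;gap'
}
sess_ids = [{'lab': 'lab0', 'expt': 'expt0', 'animal': 'animal0', 'session': 'session0'}]
data_generator = build_data_generator(hparams, sess_ids, export_csv=False)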
Example #2
def get_best_model_and_data(hparams, Model=None, load_data=True, version='best', data_kwargs=None):
    """Load the best model (and data) defined by hparams out of all available test-tube versions.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify both a model and the associated data
    Model : :obj:`behavenet.models` object, optional
        model type
    load_data : :obj:`bool`, optional
        if `False` then data generator is not returned
    version : :obj:`str` or :obj:`int`, optional
        can be 'best' to load the best model, an integer to load a specific version, or
        :obj:`NoneType` to match the version to the supplied hparams
    data_kwargs : :obj:`dict`, optional
        additional kwargs for data generator

    Returns
    -------
    :obj:`tuple`
        - model (:obj:`behavenet.models` object)
        - data generator (:obj:`ConcatSessionsGenerator` object or :obj:`NoneType`)

    """

    import torch
    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs

    # get session_dir
    hparams['session_dir'], sess_ids = get_session_dir(
        hparams, session_source=hparams.get('all_source', 'save'))
    expt_dir = get_expt_dir(hparams)

    # get best model version
    if version == 'best':
        best_version_int = get_best_model_version(expt_dir)[0]
        best_version = str('version_{}'.format(best_version_int))
    elif version is None:
        # try to match hparams
        _, version_hp = experiment_exists(hparams, which_version=True)
        best_version = str('version_{}'.format(version_hp))
    else:
        if isinstance(version, str) and version[0] == 'v':
            # assume we got a string of the form 'version_%i'
            best_version = version
        else:
            best_version = str('version_{}'.format(version))
    # construct paths to the saved hyperparameters and model weights
    version_dir = os.path.join(expt_dir, best_version)
    arch_file = os.path.join(version_dir, 'meta_tags.pkl')
    model_file = os.path.join(version_dir, 'best_val_model.pt')
    if not os.path.exists(model_file) and not os.path.exists(model_file + '.meta'):
        model_file = os.path.join(version_dir, 'best_val_model.ckpt')
    print('Loading model defined in %s' % arch_file)

    with open(arch_file, 'rb') as f:
        hparams_new = pickle.load(f)

    # update paths if performing analysis on a different machine
    hparams_new['data_dir'] = hparams['data_dir']
    hparams_new['session_dir'] = hparams['session_dir']
    hparams_new['expt_dir'] = expt_dir
    hparams_new['use_output_mask'] = hparams.get('use_output_mask', False)
    hparams_new['use_label_mask'] = hparams.get('use_label_mask', False)
    hparams_new['device'] = hparams.get('device', 'cpu')

    # build data generator
    hparams_new, signals, transforms, paths = get_data_generator_inputs(hparams_new, sess_ids)
    if load_data:
        # sometimes we want a single data_generator for multiple models
        if data_kwargs is None:
            data_kwargs = {}
        data_generator = ConcatSessionsGenerator(
            hparams_new['data_dir'], sess_ids,
            signals_list=signals, transforms_list=transforms, paths_list=paths,
            device=hparams_new['device'], as_numpy=hparams_new['as_numpy'],
            batch_load=hparams_new['batch_load'], rng_seed=hparams_new['rng_seed_data'],
            train_frac=hparams_new['train_frac'], **data_kwargs)
    else:
        data_generator = None

    # build model
    if Model is None:
        if hparams['model_class'] == 'ae':
            from behavenet.models import AE as Model
        elif hparams['model_class'] == 'vae':
            from behavenet.models import VAE as Model
        elif hparams['model_class'] == 'cond-ae':
            from behavenet.models import ConditionalAE as Model
        elif hparams['model_class'] == 'cond-vae':
            from behavenet.models import ConditionalVAE as Model
        elif hparams['model_class'] == 'cond-ae-msp':
            from behavenet.models import AEMSP as Model
        elif hparams['model_class'] == 'beta-tcvae':
            from behavenet.models import BetaTCVAE as Model
        elif hparams['model_class'] == 'ps-vae':
            from behavenet.models import PSVAE as Model
        elif hparams['model_class'] == 'msps-vae':
            from behavenet.models import MSPSVAE as Model
        elif hparams['model_class'] == 'labels-images':
            from behavenet.models import ConvDecoder as Model
        elif hparams['model_class'] in [
                'neural-ae', 'neural-ae-me', 'neural-arhmm', 'neural-labels',
                'ae-neural', 'arhmm-neural', 'labels-neural']:
            from behavenet.models import Decoder as Model
        elif hparams['model_class'] == 'arhmm':
            raise NotImplementedError('Cannot use get_best_model_and_data() for ssm models')
        else:
            raise NotImplementedError(
                "model class '%s' not recognized" % hparams['model_class'])

    model = Model(hparams_new)
    model.version = int(best_version.split('_')[1])
    model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
    model.to(hparams_new['device'])
    model.eval()

    return model, data_generator
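A hedged usage sketch (not part of the original source): it assumes an autoencoder has already been fit with test-tube and that hparams carries the directory and model keys expected by get_session_dir, get_expt_dir, and get_data_generator_inputs; all values below are placeholders.

hparams = {
    'data_dir': '/path/to/data',
    'save_dir': '/path/to/results',
    'lab': 'lab0', 'expt': 'expt0', 'animal': 'animal0', 'session': 'session0',
    'experiment_name': 'ae-example',
    'model_class': 'ae', 'model_type': 'conv', 'n_ae_latents': 12,
    'device': 'cpu',
}
# load the best AE version along with a matching data generator
model, data_generator = get_best_model_and_data(hparams, Model=None, version='best')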
Example #3
def make_syllable_movies_wrapper(hparams,
                                 save_file,
                                 sess_idx=0,
                                 dtype='test',
                                 max_frames=400,
                                 frame_rate=10,
                                 min_threshold=0,
                                 n_buffer=5,
                                 n_pre_frames=3,
                                 n_rows=None,
                                 single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    This is a high-level function that loads the arhmm model described in the hparams dictionary
    and produces the necessary states/video frames.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make video with; 'train' | 'val' | 'test'
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    min_threshold : :obj:`int`, optional
        minimum number of frames in a syllable run to be considered for movie
    n_buffer : :obj:`int`, optional
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`, optional
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`, optional
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`, optional
        choose only a single state for movie

    """
    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # load images, latents, and states
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)
    hparams['load_videos'] = True
    hparams, signals, transforms, paths = get_data_generator_inputs(
        hparams, sess_ids)
    data_generator = ConcatSessionsGenerator(
        hparams['data_dir'],
        sess_ids,
        signals_list=[signals[sess_idx]],
        transforms_list=[transforms[sess_idx]],
        paths_list=[paths[sess_idx]],
        device='cpu',
        as_numpy=True,
        batch_load=False,
        rng_seed=hparams['rng_seed_data'])
    ims_orig = data_generator.datasets[sess_idx].data['images']
    del data_generator  # free up memory

    # get tt version number
    _, version = experiment_exists(hparams, which_version=True)
    print('producing syllable videos for arhmm %s' % version)
    # load latents/labels
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            latents = pickle.load(f)
    trial_idxs = latents['trials'][dtype]
    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)
    # infer discrete states
    states = [
        hmm.most_likely_states(latents['latents'][s])
        for s in latents['trials'][dtype]
    ]
    if len(states) == 0:
        raise ValueError('No latents for dtype=%s' % dtype)

    # find runs of discrete states; state indices is a list, each entry of which is a np array with
    # shape (n_state_instances, 3), where the 3 values are:
    # chunk_idx, chunk_start_idx, chunk_end_idx
    # chunk_idx is in [0, n_chunks], and indexes trial_idxs
    state_indices = get_discrete_chunks(states, include_edges=True)
    K = len(state_indices)

    # get all examples over the minimum state length threshold
    over_threshold_instances = [[] for _ in range(K)]
    for i_state in range(K):
        if state_indices[i_state].shape[0] > 0:
            state_lens = np.diff(state_indices[i_state][:, 1:3], axis=1)
            over_idxs = state_lens > min_threshold
            over_threshold_instances[i_state] = state_indices[i_state][
                over_idxs[:, 0]]
            np.random.shuffle(
                over_threshold_instances[i_state])  # shuffle instances

    make_syllable_movies(ims_orig=ims_orig,
                         state_list=over_threshold_instances,
                         trial_idxs=trial_idxs,
                         save_file=save_file,
                         max_frames=max_frames,
                         frame_rate=frame_rate,
                         n_buffer=n_buffer,
                         n_pre_frames=n_pre_frames,
                         n_rows=n_rows,
                         single_syllable=single_syllable)
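A hedged usage sketch (not part of the original source): it assumes both an AE and an arhmm have already been fit; the experiment names, model types, and state/latent counts below are placeholders, and additional keys may be required by get_session_dir and get_transforms_paths.

hparams = {
    'data_dir': '/path/to/data',
    'save_dir': '/path/to/results',
    'lab': 'lab0', 'expt': 'expt0', 'animal': 'animal0', 'session': 'session0',
    'experiment_name': 'arhmm-example', 'model_class': 'arhmm', 'n_arhmm_states': 8,
    'ae_experiment_name': 'ae-example', 'ae_model_class': 'ae', 'ae_model_type': 'conv',
    'n_ae_latents': 12,
    'rng_seed_data': 0,
}
make_syllable_movies_wrapper(
    hparams, save_file='/path/to/results/syllables.mp4',
    dtype='test', max_frames=200, frame_rate=15, min_threshold=5)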
Example #4
def plot_neural_reconstruction_traces_wrapper(hparams,
                                              save_file=None,
                                              trial=None,
                                              xtick_locs=None,
                                              frame_rate=None,
                                              format='png',
                                              **kwargs):
    """Plot ae latents and their neural reconstructions.

    This is a high-level function that loads the model described in the hparams dictionary and
    produces the necessary predicted latents.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an ae latent decoder
    save_file : :obj:`str`
        full save file (path and filename)
    trial : :obj:`int`, optional
        if :obj:`NoneType`, use first test trial
    xtick_locs : :obj:`array-like`, optional
        tick locations in units of bins
    frame_rate : :obj:`float`, optional
        frame rate of behavioral video; used to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle of plot

    """

    # find good trials
    import copy
    from behavenet.data.utils import get_transforms_paths
    from behavenet.data.data_generator import ConcatSessionsGenerator

    # ae data
    hparams_ae = copy.copy(hparams)
    hparams_ae['experiment_name'] = hparams['ae_experiment_name']
    hparams_ae['model_class'] = hparams['ae_model_class']
    hparams_ae['model_type'] = hparams['ae_model_type']

    ae_transform, ae_path = get_transforms_paths('ae_latents', hparams_ae,
                                                 None)

    # ae predictions data
    hparams_dec = copy.copy(hparams)
    hparams_dec['neural_ae_experiment_name'] = hparams[
        'decoder_experiment_name']
    hparams_dec['neural_ae_model_class'] = hparams['decoder_model_class']
    hparams_dec['neural_ae_model_type'] = hparams['decoder_model_type']
    ae_pred_transform, ae_pred_path = get_transforms_paths(
        'neural_ae_predictions', hparams_dec, None)

    signals = ['ae_latents', 'ae_predictions']
    transforms = [ae_transform, ae_pred_transform]
    paths = [ae_path, ae_pred_path]

    data_generator = ConcatSessionsGenerator(hparams['data_dir'], [hparams],
                                             signals_list=[signals],
                                             transforms_list=[transforms],
                                             paths_list=[paths],
                                             device='cpu',
                                             as_numpy=False,
                                             batch_load=True,
                                             rng_seed=0)

    if trial is None:
        # choose first test trial
        trial = data_generator.datasets[0].batch_idxs['test'][0]

    batch = data_generator.datasets[0][trial]
    traces_ae = batch['ae_latents'].cpu().detach().numpy()
    traces_neural = batch['ae_predictions'].cpu().detach().numpy()

    # only plot valid segment of data
    n_max_lags = hparams.get('n_max_lags', 0)
    if n_max_lags > 0:
        fig = plot_neural_reconstruction_traces(
            traces_ae[n_max_lags:-n_max_lags],
            traces_neural[n_max_lags:-n_max_lags], save_file, xtick_locs,
            frame_rate, format, **kwargs)
    else:
        fig = plot_neural_reconstruction_traces(traces_ae, traces_neural,
                                                save_file, xtick_locs,
                                                frame_rate, format, **kwargs)
    return fig
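A hedged usage sketch (not part of the original source): it assumes an AE and a neural decoder have already been fit; the experiment names and model classes/types below are placeholders.

hparams = {
    'data_dir': '/path/to/data',
    'save_dir': '/path/to/results',
    'lab': 'lab0', 'expt': 'expt0', 'animal': 'animal0', 'session': 'session0',
    'ae_experiment_name': 'ae-example', 'ae_model_class': 'ae', 'ae_model_type': 'conv',
    'decoder_experiment_name': 'decoder-example',
    'decoder_model_class': 'neural-ae', 'decoder_model_type': 'mlp',
    'n_max_lags': 4,
}
fig = plot_neural_reconstruction_traces_wrapper(
    hparams, save_file='/path/to/results/ae_reconstruction_traces',
    xtick_locs=[0, 100, 200], frame_rate=30.0, format='png')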
Example #5
def get_best_model_and_data(hparams, Model, load_data=True, version='best', data_kwargs=None):
    """Load the best model (and data) defined by hparams out of all available test-tube versions.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify both a model and the associated data
    Model : :obj:`behavenet.models` object
        model type
    load_data : :obj:`bool`, optional
        if `False` then data generator is not returned
    version : :obj:`str` or :obj:`int`, optional
        can be 'best' to load the best model, or an integer / 'version_%i' string to load
        a specific version
    data_kwargs : :obj:`dict`, optional
        additional kwargs for data generator

    Returns
    -------
    :obj:`tuple`
        - model (:obj:`behavenet.models` object)
        - data generator (:obj:`ConcatSessionsGenerator` object or :obj:`NoneType`)

    """

    import torch
    from behavenet.data.data_generator import ConcatSessionsGenerator

    # get session_dir
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    expt_dir = get_expt_dir(hparams)

    # get best model version
    if version == 'best':
        best_version_int = get_best_model_version(expt_dir)[0]
        best_version = str('version_{}'.format(best_version_int))
    else:
        if isinstance(version, str) and version[0] == 'v':
            # assume we got a string of the form 'version_%i'
            best_version = version
        else:
            best_version = str('version_{}'.format(version))
    # construct paths to the saved hyperparameters and model weights
    version_dir = os.path.join(expt_dir, best_version)
    arch_file = os.path.join(version_dir, 'meta_tags.pkl')
    model_file = os.path.join(version_dir, 'best_val_model.pt')
    if not os.path.exists(model_file) and not os.path.exists(model_file + '.meta'):
        model_file = os.path.join(version_dir, 'best_val_model.ckpt')
    print('Loading model defined in %s' % arch_file)

    with open(arch_file, 'rb') as f:
        hparams_new = pickle.load(f)

    # update paths if performing analysis on a different machine
    hparams_new['data_dir'] = hparams['data_dir']
    hparams_new['session_dir'] = hparams['session_dir']
    hparams_new['expt_dir'] = expt_dir
    hparams_new['use_output_mask'] = hparams.get('use_output_mask', False)
    hparams_new['device'] = 'cpu'

    # build data generator
    hparams_new, signals, transforms, paths = get_data_generator_inputs(hparams_new, sess_ids)
    if load_data:
        # sometimes we want a single data_generator for multiple models
        if data_kwargs is None:
            data_kwargs = {}
        data_generator = ConcatSessionsGenerator(
            hparams_new['data_dir'], sess_ids,
            signals_list=signals, transforms_list=transforms, paths_list=paths,
            device=hparams_new['device'], as_numpy=hparams_new['as_numpy'],
            batch_load=hparams_new['batch_load'], rng_seed=hparams_new['rng_seed_data'],
            **data_kwargs)
    else:
        data_generator = None

    # build models
    model = Model(hparams_new)
    model.version = int(best_version.split('_')[1])
    model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
    model.to(hparams_new['device'])
    model.eval()

    return model, data_generator
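Unlike Example #2, this variant requires the Model class to be passed explicitly and always places the loaded model (and data generator) on the CPU. A hedged usage sketch (not part of the original source), reusing placeholder hparams like those shown after Example #2:

from behavenet.models import AE

# load a specific test-tube version without building a data generator
model, _ = get_best_model_and_data(hparams, AE, load_data=False, version=0)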