Example #1
def check_model(config_dicts, dirs):
    """Check whether the model specified by the config dicts has already been fit."""
    from behavenet.fitting.utils import experiment_exists

    # ANSI codes for colored terminal output; assumed to be module-level constants
    # in the original source, reproduced here so the snippet is self-contained
    BOLD, CEND = '\033[1m', '\033[0m'
    CGREEN, CRED = '\033[92m', '\033[91m'

    # merge the individual config dicts into a single hparams dict
    hparams = {
        **config_dicts['data'], **config_dicts['model'], **config_dicts['training'],
        **config_dicts['compute']}
    hparams['save_dir'] = dirs.save_dir
    hparams['data_dir'] = dirs.data_dir
    # pick out a single model if multiple were fit with test tube
    for key, val in hparams.items():
        if isinstance(val, list):
            hparams[key] = val[-1]
    if experiment_exists(hparams):
        result_str = BOLD + CGREEN + 'passed' + CEND
    else:
        result_str = BOLD + CRED + 'failed' + CEND
    return result_str
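
A minimal usage sketch; the config keys and directory attributes shown are hypothetical stand-ins for whatever a real behavenet config provides:

from types import SimpleNamespace

config_dicts = {
    'data': {'lab': 'lab00', 'expt': 'expt00'},          # hypothetical entries
    'model': {'model_class': 'ae', 'n_ae_latents': 9},
    'training': {'rng_seed_data': 0},
    'compute': {'device': 'cpu'},
}
dirs = SimpleNamespace(save_dir='/path/to/results', data_dir='/path/to/data')
print(check_model(config_dicts, dirs))  # bold green 'passed' or bold red 'failed'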
Example #2
def make_syllable_movies_wrapper(hparams,
                                 save_file,
                                 sess_idx=0,
                                 dtype='test',
                                 max_frames=400,
                                 frame_rate=10,
                                 min_threshold=0,
                                 n_buffer=5,
                                 n_pre_frames=3,
                                 n_rows=None,
                                 single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    This is a high-level function that loads the arhmm model described in the hparams dictionary
    and produces the necessary states/video frames.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        type of trials to make the video from; 'train' | 'val' | 'test'
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    min_threshold : :obj:`int`, optional
        minimum number of frames in a syllable run to be considered for movie
    n_buffer : :obj:`int`, optional
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`, optional
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`, optional
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`, optional
        restrict the movie to a single state

    """
    import os
    import pickle

    import numpy as np

    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # NOTE: `get_discrete_chunks` and `make_syllable_movies` (used below) are assumed
    # to be helpers defined alongside this wrapper in the original module

    # load images, latents, and states
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)
    hparams['load_videos'] = True
    hparams, signals, transforms, paths = get_data_generator_inputs(
        hparams, sess_ids)
    data_generator = ConcatSessionsGenerator(
        hparams['data_dir'],
        sess_ids,
        signals_list=[signals[sess_idx]],
        transforms_list=[transforms[sess_idx]],
        paths_list=[paths[sess_idx]],
        device='cpu',
        as_numpy=True,
        batch_load=False,
        rng_seed=hparams['rng_seed_data'])
    ims_orig = data_generator.datasets[sess_idx].data['images']
    del data_generator  # free up memory

    # get tt version number
    _, version = experiment_exists(hparams, which_version=True)
    print('producing syllable videos for arhmm %s' % version)
    # load latents/labels
    if 'labels' in hparams['model_class']:
        from behavenet.data.utils import load_labels_like_latents
        latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            latents = pickle.load(f)
    trial_idxs = latents['trials'][dtype]
    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)
    # infer discrete states
    states = [hmm.most_likely_states(latents['latents'][s]) for s in trial_idxs]
    if len(states) == 0:
        raise ValueError('No latents for dtype=%s' % dtype)

    # find runs of discrete states; state_indices is a list with one entry per state, each a
    # np array of shape (n_state_instances, 3) whose columns are:
    # chunk_idx, chunk_start_idx, chunk_end_idx
    # chunk_idx is in [0, n_chunks] and indexes trial_idxs
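    # e.g. state_indices[2] = np.array([[0, 10, 25], [4, 0, 12]]) would indicate that state 2
    # occurs twice: frames 10-25 of trial_idxs[0] and frames 0-12 of trial_idxs[4]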
    state_indices = get_discrete_chunks(states, include_edges=True)
    K = len(state_indices)

    # get all examples over the minimum state length threshold
    over_threshold_instances = [[] for _ in range(K)]
    for i_state in range(K):
        if state_indices[i_state].shape[0] > 0:
            state_lens = np.diff(state_indices[i_state][:, 1:3], axis=1)
            over_idxs = state_lens > min_threshold
            over_threshold_instances[i_state] = state_indices[i_state][over_idxs[:, 0]]
            np.random.shuffle(over_threshold_instances[i_state])  # randomize instance order

    make_syllable_movies(ims_orig=ims_orig,
                         state_list=over_threshold_instances,
                         trial_idxs=trial_idxs,
                         save_file=save_file,
                         max_frames=max_frames,
                         frame_rate=frame_rate,
                         n_buffer=n_buffer,
                         n_pre_frames=n_pre_frames,
                         n_rows=n_rows,
                         single_syllable=single_syllable)
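
A hedged usage sketch for the wrapper; the hparams entries are assumptions about what is needed to specify an arhmm fit (see the docstring), not a verified minimal set:

hparams = {
    'data_dir': '/path/to/data',      # hypothetical values throughout
    'save_dir': '/path/to/results',
    'model_class': 'arhmm',
    'rng_seed_data': 0,
    # ... plus lab/expt/animal/session info and arhmm hyperparameters
}
make_syllable_movies_wrapper(
    hparams, save_file='/path/to/results/syllables.mp4', dtype='test',
    max_frames=400, frame_rate=10)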
Example #3
def get_model_latents_states(hparams,
                             version,
                             sess_idx=0,
                             return_samples=0,
                             cond_sampling=False,
                             dtype='test',
                             dtypes=('train', 'val', 'test'),
                             rng_seed=0):
    """Return arhmm defined in :obj:`hparams` with associated latents and states.

    Can also return sampled latents and states.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    version : :obj:`str` or :obj:`int`
        test tube model version (can be 'best')
    sess_idx : :obj:`int`, optional
        session index into data generator
    return_samples : :obj:`int`, optional
        number of trials to sample from model
    cond_sampling : :obj:`bool`, optional
        if :obj:`True` return samples conditioned on most likely state sequence; else return
        unconditioned samples
    dtype : :obj:`str`, optional
        trial type to use for conditional sampling; 'train' | 'val' | 'test'
    dtypes : :obj:`array-like`, optional
        trial types for which to collect latents and states
    rng_seed : :obj:`int`, optional
        random number generator seed to control sampling

    Returns
    -------
    :obj:`dict`
        - 'model' (:obj:`ssm.HMM` object)
        - 'latents' (:obj:`dict`): latents from train, val and test trials
        - 'states' (:obj:`dict`): states from train, val and test trials
        - 'trial_idxs' (:obj:`dict`): trial indices from train, val and test trials
        - 'latents_gen' (:obj:`list`): sampled latents (empty unless :obj:`return_samples>0`)
        - 'states_gen' (:obj:`list`): sampled states (empty unless :obj:`return_samples>0`)

    """
    import os
    import pickle

    import numpy as np

    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # get version/model
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'],
                                         measure='val_loss',
                                         best_def='max')[0]
    else:
        _, version = experiment_exists(hparams, which_version=True)
    if version is None:
        raise FileNotFoundError(
            'Could not find the specified model version in %s' %
            hparams['expt_dir'])

    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # load latents/labels
    if 'labels' in hparams['model_class']:
        from behavenet.data.utils import load_labels_like_latents
        all_latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            all_latents = pickle.load(f)

    # collect inferred latents/states
    trial_idxs = {}
    latents = {}
    states = {}
    for data_type in dtypes:
        trial_idxs[data_type] = all_latents['trials'][data_type]
        latents[data_type] = [
            all_latents['latents'][i_trial]
            for i_trial in trial_idxs[data_type]
        ]
        states[data_type] = [
            hmm.most_likely_states(x) for x in latents[data_type]
        ]

    # collect sampled latents/states
    if return_samples > 0:
        states_gen = []
        np.random.seed(rng_seed)
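        # conditional sampling regenerates latents from the autoregressive observation model
        # while holding the inferred discrete state sequence fixed; the first time step of
        # each segment is conditioned on the real initial latent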
        if cond_sampling:
            n_latents = latents[dtype][0].shape[1]
            latents_gen = [
                np.zeros((len(state_seg), n_latents))
                for state_seg in states[dtype]
            ]
            for i_seg, state_seg in enumerate(states[dtype]):
                for i_t in range(len(state_seg)):
                    if i_t >= 1:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents_gen[i_seg][:i_t],
                            input=np.zeros(0))
                    else:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents[dtype][i_seg][0].reshape((1, n_latents)),
                            input=np.zeros(0))
        else:
            latents_gen = []
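            # sample `offset` extra time steps and discard them, presumably as burn-in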
            offset = 200
            for i in range(return_samples):
                these_states_gen, these_latents_gen = hmm.sample(
                    latents[dtype][0].shape[0] + offset)
                latents_gen.append(these_latents_gen[offset:])
                states_gen.append(these_states_gen[offset:])
    else:
        latents_gen = []
        states_gen = []

    return_dict = {
        'model': hmm,
        'latents': latents,
        'states': states,
        'trial_idxs': trial_idxs,
        'latents_gen': latents_gen,
        'states_gen': states_gen,
    }
    return return_dict
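
A hedged usage sketch; assumes `hparams` already specifies a fit arhmm:

outputs = get_model_latents_states(hparams, version='best', return_samples=10)
first_test_states = outputs['states']['test'][0]  # discrete state sequence, first test trial
sampled = outputs['latents_gen']                  # 10 unconditioned latent trajectories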
Example #4
def load_metrics_csv_as_df(hparams,
                           lab,
                           expt,
                           metrics_list,
                           test=False,
                           version='best'):
    """Load metrics csv file and return as a pandas dataframe for easy plotting.

    Parameters
    ----------
    hparams : :obj:`dict`
        requires `sessions_csv`, `multisession`, `lab`, `expt`, `animal` and `session`
    lab : :obj:`str`
        for `get_lab_example`
    expt : :obj:`str`
        for `get_lab_example`
    metrics_list : :obj:`list`
        names of metrics to pull from csv; do not prepend with 'tr', 'val', or 'test'
    test : :obj:`bool`, optional
        :obj:`True` to only return test values (computed once at end of training)
    version : :obj:`str`, :obj:`int`, or :obj:`NoneType`, optional
        'best' to find the best model in the test tube experiment, :obj:`None` to find the model
        with hyperparameters defined in `hparams`, an int to load a specific model version

    Returns
    -------
    :obj:`pandas.DataFrame` object

    """

    import os

    import pandas as pd

    # the import locations below are assumed from the behavenet package layout
    from behavenet import get_lab_example
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir
    from behavenet.fitting.utils import read_session_info_from_csv

    # programmatically fill out other hparams options
    get_lab_example(hparams, lab, expt)
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # find metrics csv file
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'])[0]
    elif not isinstance(version, int):
        # version is None: find the model matching the hyperparameters in `hparams`
        _, version = experiment_exists(hparams, which_version=True)
    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)
    metric_file = os.path.join(version_dir, 'metrics.csv')
    metrics = pd.read_csv(metric_file)

    # collect data from csv file
    sess_ids = read_session_info_from_csv(
        os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = []
    for sess_id in sess_ids:
        sess_ids_strs.append('%s/%s' % (sess_id['animal'], sess_id['session']))
    metrics_df = []
    for i, row in metrics.iterrows():
        dataset = 'all' if row['dataset'] == -1 else sess_ids_strs[row['dataset']]
        if test:
            test_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'test'
            }
            for metric in metrics_list:
                metrics_df.append(pd.DataFrame({
                    **test_dict,
                    'loss': metric,
                    'val': row['test_%s' % metric],
                }, index=[0]))
        else:
            # make dict for val data
            val_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'val'
            }
            for metric in metrics_list:
                metrics_df.append(pd.DataFrame({
                    **val_dict,
                    'loss': metric,
                    'val': row['val_%s' % metric],
                }, index=[0]))
            # NOTE: the commented-out lines below are the old version, which returned a single
            # dataframe row containing all losses per epoch; the new version creates one row per
            # loss, making the result easy to use with seaborn's FacetGrid for multi-axis plotting
            # for metric in metrics_list:
            #     val_dict[metric] = row['val_%s' % metric]
            # metrics_df.append(pd.DataFrame(val_dict, index=[0]))
            # make dict for train data
            tr_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'train'
            }
            for metric in metrics_list:
                metrics_df.append(pd.DataFrame({
                    **tr_dict,
                    'loss': metric,
                    'val': row['tr_%s' % metric],
                }, index=[0]))
            # for metric in metrics_list:
            #     tr_dict[metric] = row['tr_%s' % metric]
            # metrics_df.append(pd.DataFrame(tr_dict, index=[0]))
    return pd.concat(metrics_df, sort=True)
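
Because the returned dataframe is in long format (one row per metric per epoch), it plugs directly into seaborn's FacetGrid, as the note above suggests; the lab/expt/metric names below are hypothetical:

import matplotlib.pyplot as plt
import seaborn as sns

df = load_metrics_csv_as_df(hparams, 'lab00', 'expt00', metrics_list=['loss', 'mse'])
g = sns.FacetGrid(df, col='loss', hue='dtype')
g.map(plt.plot, 'epoch', 'val')
g.add_legend()
plt.show()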