Example #1
def get_model_latents_states(hparams,
                             version,
                             sess_idx=0,
                             return_samples=0,
                             cond_sampling=False,
                             dtype='test',
                             dtypes=['train', 'val', 'test'],
                             rng_seed=0):
    """Return arhmm defined in :obj:`hparams` with associated latents and states.

    Can also return sampled latents and states.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    version : :obj:`str` or :obj:`int`
        test tube model version (can be 'best')
    sess_idx : :obj:`int`, optional
        session index into data generator
    return_samples : :obj:`int`, optional
        number of trials to sample from model
    cond_sampling : :obj:`bool`, optional
        if :obj:`True` return samples conditioned on most likely state sequence; else return
        unconditioned samples
    dtype : :obj:`str`, optional
        trial type to use for conditional sampling; 'train' | 'val' | 'test'
    dtypes : :obj:`array-like`, optional
        trial types for which to collect latents and states
    rng_seed : :obj:`int`, optional
        random number generator seed to control sampling

    Returns
    -------
    :obj:`dict`
        - 'model' (:obj:`ssm.HMM` object)
        - 'latents' (:obj:`dict`): latents from train, val and test trials
        - 'states' (:obj:`dict`): states from train, val and test trials
        - 'trial_idxs' (:obj:`dict`): trial indices from train, val and test trials
        - 'latents_gen' (:obj:`list`): sampled latents, one array per sampled trial
        - 'states_gen' (:obj:`list`): sampled state sequences, one array per sampled trial

    """
    import os
    import pickle

    import numpy as np

    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # get version/model
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'],
                                         measure='val_loss',
                                         best_def='max')[0]
    else:
        _, version = experiment_exists(hparams, which_version=True)
    if version is None:
        raise FileNotFoundError(
            'Could not find the specified model version in %s' %
            hparams['expt_dir'])

    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version,
                              'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # load latents/labels
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        all_latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams,
                                               sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            all_latents = pickle.load(f)

    # collect inferred latents/states
    trial_idxs = {}
    latents = {}
    states = {}
    for data_type in dtypes:
        trial_idxs[data_type] = all_latents['trials'][data_type]
        latents[data_type] = [
            all_latents['latents'][i_trial]
            for i_trial in trial_idxs[data_type]
        ]
        states[data_type] = [
            hmm.most_likely_states(x) for x in latents[data_type]
        ]

    # collect sampled latents/states
    if return_samples > 0:
        states_gen = []
        np.random.seed(rng_seed)
        if cond_sampling:
            n_latents = latents[dtype][0].shape[1]
            latents_gen = [
                np.zeros((len(state_seg), n_latents))
                for state_seg in states[dtype]
            ]
            for i_seg, state_seg in enumerate(states[dtype]):
                for i_t in range(len(state_seg)):
                    if i_t >= 1:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents_gen[i_seg][:i_t],
                            input=np.zeros(0))
                    else:
                        latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                            states[dtype][i_seg][i_t],
                            latents[dtype][i_seg][0].reshape((1, n_latents)),
                            input=np.zeros(0))
        else:
            latents_gen = []
            # sample extra time steps up front and discard them as burn-in, so returned
            # samples reflect the model's long-run behavior rather than its initial state
            offset = 200
            for i in range(return_samples):
                these_states_gen, these_latents_gen = hmm.sample(
                    latents[dtype][0].shape[0] + offset)
                latents_gen.append(these_latents_gen[offset:])
                states_gen.append(these_states_gen[offset:])
    else:
        latents_gen = []
        states_gen = []

    return_dict = {
        'model': hmm,
        'latents': latents,
        'states': states,
        'trial_idxs': trial_idxs,
        'latents_gen': latents_gen,
        'states_gen': states_gen,
    }
    return return_dict
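
# --- usage sketch (added illustration, not from the original source) ---
# a minimal sketch assuming a trained arhmm already exists on disk; every hparams key
# below is hypothetical and only illustrates the kind of information required
hparams = {
    'data_dir': '/path/to/data',            # hypothetical paths and values
    'save_dir': '/path/to/results',
    'lab': 'lab', 'expt': 'expt', 'animal': 'animal', 'session': 'session',
    'model_class': 'arhmm',
    'experiment_name': 'arhmm_expt',
    'n_arhmm_states': 8, 'kappa': 0, 'noise_type': 'gaussian',
}
outputs = get_model_latents_states(hparams, version='best', return_samples=10)
hmm = outputs['model']                      # fitted ssm.HMM
states_test = outputs['states']['test']     # most-likely state sequence per test trial
latents_gen = outputs['latents_gen']        # 10 unconditioned sampled trajectories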
Example #2
def get_transforms_paths(data_type, hparams, sess_id, check_splits=True):
    """Helper function for generating session-specific transforms and paths.

    Parameters
    ----------
    data_type : :obj:`str`
        'neural' | 'ae_latents' | 'arhmm_states' | 'neural_ae_predictions' |
        'neural_arhmm_predictions'
    hparams : :obj:`dict`
        - required keys for :obj:`data_type=neural`: 'neural_type', 'neural_thresh'
        - required keys for :obj:`data_type=ae_latents`: 'ae_experiment_name', 'ae_model_type',
          'n_ae_latents', and 'ae_version' or 'ae_latents_file'; these last two options define,
          respectively, a specific ae version (as 'best' or an int) or a path to a specific ae
          latents pickle file.
        - required keys for :obj:`data_type=arhmm_states`: 'arhmm_experiment_name',
          'n_arhmm_states', 'kappa', 'noise_type', 'n_ae_latents', and 'arhmm_version' or
          'arhmm_states_file'; these last two options define, respectively, a specific arhmm
          version (as 'best' or an int) or a path to a specific arhmm states pickle file.
        - required keys for :obj:`data_type=neural_ae_predictions`: 'neural_ae_experiment_name',
          'neural_ae_model_type', and 'neural_ae_version' or 'ae_predictions_file', plus keys for
          the neural and ae_latents data types.
        - required keys for :obj:`data_type=neural_arhmm_predictions`:
          'neural_arhmm_experiment_name', 'neural_arhmm_model_type', and 'neural_arhmm_version' or
          'arhmm_predictions_file', plus keys for the neural and arhmm_states data types.
    sess_id : :obj:`dict`
        session-specific dict with keys 'lab', 'expt', 'animal', 'session'; if :obj:`None`,
        session info is taken from :obj:`hparams`
    check_splits : :obj:`bool`, optional
        check data splits and data rng seed between hparams and loaded model outputs (e.g. latents)

    Returns
    -------
    :obj:`tuple`
        - transform (:obj:`behavenet.data.transforms.Transform` object): session-specific transform
        - path (:obj:`str`): session-specific path

    """

    import os

    import numpy as np

    from behavenet.data.transforms import BlockShuffle
    from behavenet.data.transforms import Compose
    from behavenet.data.transforms import MotionEnergy
    from behavenet.data.transforms import SelectIdxs
    from behavenet.data.transforms import Threshold
    from behavenet.data.transforms import ZScore
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir

    # check for multisession by comparing hparams and sess_id
    hparams_ = {key: hparams[key] for key in ['lab', 'expt', 'animal', 'session']}
    if sess_id is None:
        sess_id = hparams_

    sess_id_str = str('%s_%s_%s_%s_' % (
        sess_id['lab'], sess_id['expt'], sess_id['animal'], sess_id['session']))

    if data_type == 'neural':

        check_splits = False

        path = os.path.join(
            hparams['data_dir'], sess_id['lab'], sess_id['expt'], sess_id['animal'],
            sess_id['session'], 'data.hdf5')

        transforms_ = []

        # filter neural data by indices (regions, cell types, etc)
        if hparams.get('subsample_method', 'none') != 'none':
            # get indices
            sampling = hparams['subsample_method']
            idxs_name = hparams['subsample_idxs_name']
            idxs_dict = get_region_list(hparams)
            if sampling == 'single':
                idxs = idxs_dict[idxs_name]
            elif sampling == 'loo':
                idxs = []
                for idxs_key, idxs_val in idxs_dict.items():
                    if idxs_key != idxs_name:
                        idxs.append(idxs_val)
                idxs = np.concatenate(idxs)
            else:
                raise ValueError('"%s" is an invalid index sampling option' % sampling)
            transforms_.append(SelectIdxs(idxs, str('%s-%s' % (idxs_name, sampling))))

        # filter neural data by activity
        if hparams['neural_type'] == 'spikes':
            if hparams['neural_thresh'] > 0:
                transforms_.append(Threshold(
                    threshold=hparams['neural_thresh'],
                    bin_size=hparams['neural_bin_size']))
        elif hparams['neural_type'] == 'ca':
            if hparams['model_type'][-6:] != 'neural':
                # don't zscore if predicting calcium activity
                transforms_.append(ZScore())
        elif hparams['neural_type'] == 'ca-zscored':
            pass
        else:
            raise ValueError('"%s" is an invalid neural type' % hparams['neural_type'])

        # compose filters
        if len(transforms_) == 0:
            transform = None
        else:
            transform = Compose(transforms_)

    elif data_type == 'ae_latents' or data_type == 'latents' \
            or data_type == 'ae_latents_me' or data_type == 'latents_me':

        if data_type == 'ae_latents_me' or data_type == 'latents_me':
            transform = MotionEnergy()
        else:
            transform = None

        if 'ae_latents_file' in hparams:
            path = hparams['ae_latents_file']
        else:
            ae_dir = get_expt_dir(
                hparams, model_class=hparams['ae_model_class'],
                expt_name=hparams['ae_experiment_name'],
                model_type=hparams['ae_model_type'])
            if 'ae_version' in hparams and hparams['ae_version'] != 'best':
                # json args read as strings
                if isinstance(hparams['ae_version'], str):
                    hparams['ae_version'] = int(hparams['ae_version'])
                ae_version = str('version_%i' % hparams['ae_version'])
            else:
                ae_version = 'version_%i' % get_best_model_version(ae_dir, 'val_loss')[0]
            ae_latents = str('%slatents.pkl' % sess_id_str)
            path = os.path.join(ae_dir, ae_version, ae_latents)

    elif data_type == 'arhmm_states' or data_type == 'states':

        if hparams.get('shuffle_rng_seed') is not None:
            transform = BlockShuffle(hparams['shuffle_rng_seed'])
        else:
            transform = None

        if 'arhmm_states_file' in hparams:
            path = hparams['arhmm_states_file']
        else:
            arhmm_dir = get_expt_dir(
                hparams, model_class='arhmm', expt_name=hparams['arhmm_experiment_name'])
            if 'arhmm_version' in hparams and isinstance(hparams['arhmm_version'], int):
                arhmm_version = str('version_%i' % hparams['arhmm_version'])
            else:
                arhmm_version = 'version_%i' % get_best_model_version(
                    arhmm_dir, 'val_loss', best_def='min')[0]
            arhmm_states = str('%sstates.pkl' % sess_id_str)
            path = os.path.join(arhmm_dir, arhmm_version, arhmm_states)

    elif data_type == 'neural_ae_predictions' or data_type == 'ae_predictions':

        transform = None

        if 'ae_predictions_file' in hparams:
            path = hparams['ae_predictions_file']
        else:
            neural_ae_dir = get_expt_dir(
                hparams, model_class='neural-ae',
                expt_name=hparams['neural_ae_experiment_name'],
                model_type=hparams['neural_ae_model_type'])
            if 'neural_ae_version' in hparams and isinstance(hparams['neural_ae_version'], int):
                neural_ae_version = str('version_%i' % hparams['neural_ae_version'])
            else:
                neural_ae_version = 'version_%i' % get_best_model_version(
                    neural_ae_dir, 'val_loss')[0]
            neural_ae_predictions = str('%spredictions.pkl' % sess_id_str)
            path = os.path.join(neural_ae_dir, neural_ae_version, neural_ae_predictions)

    elif data_type == 'neural_arhmm_predictions' or data_type == 'arhmm_predictions':

        transform = None

        if 'arhmm_predictions_file' in hparams:
            path = hparams['arhmm_predictions_file']
        else:
            neural_arhmm_dir = get_expt_dir(
                hparams, model_class='neural-arhmm',
                expt_name=hparams['neural_arhmm_experiment_name'],
                model_type=hparams['neural_arhmm_model_type'])
            if 'neural_arhmm_version' in hparams and \
                    isinstance(hparams['neural_arhmm_version'], int):
                neural_arhmm_version = str('version_%i' % hparams['neural_arhmm_version'])
            else:
                neural_arhmm_version = 'version_%i' % get_best_model_version(
                    neural_arhmm_dir, 'val_loss')[0]
            neural_arhmm_predictions = str('%spredictions.pkl' % sess_id_str)
            path = os.path.join(neural_arhmm_dir, neural_arhmm_version, neural_arhmm_predictions)

    else:
        raise ValueError('"%s" is an invalid data_type' % data_type)

    # check training data split is the same
    if check_splits:
        check_same_training_split(path, hparams)

    return transform, path
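
# --- usage sketch (added illustration, not from the original source) ---
# a minimal sketch of consuming the returned (transform, path) pair for the 'ae_latents'
# branch; `load_ae_latents` is a hypothetical helper, and transforms are assumed to be
# callables applied to each trial's array (as in behavenet.data.transforms)
import pickle

def load_ae_latents(hparams, sess_id=None):
    transform, path = get_transforms_paths('ae_latents', hparams, sess_id)
    with open(path, 'rb') as f:
        latents_dict = pickle.load(f)
    if transform is not None:
        # e.g. MotionEnergy for the 'ae_latents_me' data type
        latents_dict['latents'] = [transform(x) for x in latents_dict['latents']]
    return latents_dict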
Example #3
def get_transforms_paths(data_type, hparams, sess_id):
    """Helper function for generating session-specific transforms and paths.

    Parameters
    ----------
    data_type : :obj:`str`
        'neural' | 'ae_latents' | 'arhmm_states' | 'neural_ae_predictions' |
        'neural_arhmm_predictions'
    hparams : :obj:`dict`
        - required keys for :obj:`data_type=neural`: 'neural_type', 'neural_thresh'
        - required keys for :obj:`data_type=ae_latents`: 'ae_experiment_name', 'ae_model_type',
          'n_ae_latents', and 'ae_version' or 'ae_latents_file'; these last two options define,
          respectively, a specific ae version (as 'best' or an int) or a path to a specific ae
          latents pickle file.
        - required keys for :obj:`data_type=arhmm_states`: 'arhmm_experiment_name',
          'n_arhmm_states', 'kappa', 'noise_type', 'n_ae_latents', and 'arhmm_version' or
          'arhmm_states_file'; these last two options define, respectively, a specific arhmm
          version (as 'best' or an int) or a path to a specific arhmm states pickle file.
        - required keys for :obj:`data_type=neural_ae_predictions`: 'neural_ae_experiment_name',
          'neural_ae_model_type', and 'neural_ae_version' or 'ae_predictions_file', plus keys for
          the neural and ae_latents data types.
        - required keys for :obj:`data_type=neural_arhmm_predictions`:
          'neural_arhmm_experiment_name', 'neural_arhmm_model_type', and 'neural_arhmm_version' or
          'arhmm_predictions_file', plus keys for the neural and arhmm_states data types.
    sess_id : :obj:`dict`
        session-specific dict with keys 'lab', 'expt', 'animal', 'session'; if :obj:`None`,
        session info is taken from :obj:`hparams`

    Returns
    -------
    :obj:`tuple`
        - transform (:obj:`behavenet.data.transforms.Transform` object): session-specific transform
        - path (:obj:`str`): session-specific path

    """

    import os

    import numpy as np

    from behavenet.data.transforms import BlockShuffle
    from behavenet.data.transforms import Compose
    from behavenet.data.transforms import SelectIdxs
    from behavenet.data.transforms import Threshold
    from behavenet.data.transforms import ZScore
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir

    # check for multisession by comparing hparams and sess_id
    hparams_ = {
        key: hparams[key]
        for key in ['lab', 'expt', 'animal', 'session']
    }
    if sess_id is None:
        sess_id = hparams_

    sess_id_str = str('%s_%s_%s_%s_' % (sess_id['lab'], sess_id['expt'],
                                        sess_id['animal'], sess_id['session']))

    if data_type == 'neural':

        path = os.path.join(hparams['data_dir'], sess_id['lab'],
                            sess_id['expt'], sess_id['animal'],
                            sess_id['session'], 'data.hdf5')

        transforms_ = []

        # filter neural data by region
        if hparams.get('subsample_regions', 'none') != 'none':
            # get region and indices
            sampling = hparams['subsample_regions']
            region_name = hparams['region']
            regions = get_region_list(hparams)
            if sampling == 'single':
                idxs = regions[region_name]
            elif sampling == 'loo':
                idxs = []
                for reg_name, reg_idxs in regions.items():
                    if reg_name != region_name:
                        idxs.append(reg_idxs)
                idxs = np.concatenate(idxs)
            else:
                raise ValueError('"%s" is an invalid region sampling option' %
                                 sampling)
            transforms_.append(
                SelectIdxs(idxs, str('%s-%s' % (region_name, sampling))))

        # filter neural data by activity
        if hparams['neural_type'] == 'spikes':
            if hparams['neural_thresh'] > 0:
                transforms_.append(
                    Threshold(threshold=hparams['neural_thresh'],
                              bin_size=hparams['neural_bin_size']))
        elif hparams['neural_type'] == 'ca':
            if hparams['model_type'][-6:] != 'neural':
                # don't zscore if predicting calcium activity
                transforms_.append(ZScore())
        else:
            raise ValueError('"%s" is an invalid neural type' %
                             hparams['neural_type'])

        # compose filters
        if len(transforms_) == 0:
            transform = None
        else:
            transform = Compose(transforms_)

    elif data_type == 'ae_latents' or data_type == 'latents':
        ae_dir = get_expt_dir(hparams,
                              model_class='ae',
                              expt_name=hparams['ae_experiment_name'],
                              model_type=hparams['ae_model_type'])

        transform = None

        if 'ae_latents_file' in hparams:
            path = hparams['ae_latents_file']
        else:
            if 'ae_version' in hparams and isinstance(hparams['ae_version'],
                                                      int):
                ae_version = str('version_%i' % hparams['ae_version'])
            else:
                ae_version = 'version_%i' % get_best_model_version(
                    ae_dir, 'val_loss')[0]
            ae_latents = str('%slatents.pkl' % sess_id_str)
            path = os.path.join(ae_dir, ae_version, ae_latents)

    elif data_type == 'arhmm_states' or data_type == 'states':

        arhmm_dir = get_expt_dir(hparams,
                                 model_class='arhmm',
                                 expt_name=hparams['arhmm_experiment_name'])

        if hparams.get('shuffle_rng_seed') is not None:
            transform = BlockShuffle(hparams['shuffle_rng_seed'])
        else:
            transform = None

        if 'arhmm_states_file' in hparams:
            path = hparams['arhmm_states_file']
        else:
            if 'arhmm_version' in hparams and isinstance(
                    hparams['arhmm_version'], int):
                arhmm_version = str('version_%i' % hparams['arhmm_version'])
            else:
                arhmm_version = 'version_%i' % get_best_model_version(
                    arhmm_dir, 'val_loss', best_def='max')[0]
            arhmm_states = str('%sstates.pkl' % sess_id_str)
            path = os.path.join(arhmm_dir, arhmm_version, arhmm_states)

    elif data_type == 'neural_ae_predictions' or data_type == 'ae_predictions':

        neural_ae_dir = get_expt_dir(
            hparams,
            model_class='neural-ae',
            expt_name=hparams['neural_ae_experiment_name'],
            model_type=hparams['neural_ae_model_type'])

        transform = None
        if 'ae_predictions_file' in hparams:
            path = hparams['ae_predictions_file']
        else:
            if 'neural_ae_version' in hparams and isinstance(
                    hparams['neural_ae_version'], int):
                neural_ae_version = str('version_%i' %
                                        hparams['neural_ae_version'])
            else:
                neural_ae_version = 'version_%i' % get_best_model_version(
                    neural_ae_dir, 'val_loss')[0]
            neural_ae_predictions = str('%spredictions.pkl' % sess_id_str)
            path = os.path.join(neural_ae_dir, neural_ae_version,
                                neural_ae_predictions)

    elif data_type == 'neural_arhmm_predictions' or data_type == 'arhmm_predictions':

        neural_arhmm_dir = get_expt_dir(
            hparams,
            model_class='neural-arhmm',
            expt_name=hparams['neural_arhmm_experiment_name'],
            model_type=hparams['neural_arhmm_model_type'])

        transform = None
        if 'arhmm_predictions_file' in hparams:
            path = hparams['arhmm_predictions_file']
        else:
            if 'neural_arhmm_version' in hparams and \
                    isinstance(hparams['neural_arhmm_version'], int):
                neural_arhmm_version = str('version_%i' %
                                           hparams['neural_arhmm_version'])
            else:
                neural_arhmm_version = 'version_%i' % get_best_model_version(
                    neural_arhmm_dir, 'val_loss')[0]
            neural_arhmm_predictions = str('%spredictions.pkl' % sess_id_str)
            path = os.path.join(neural_arhmm_dir, neural_arhmm_version,
                                neural_arhmm_predictions)

    else:
        raise ValueError('"%s" is an invalid data_type' % data_type)

    return transform, path
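
# --- illustration (added, not from the original source) ---
# both versions above chain their neural-data filters with Compose; a minimal sketch of
# that pattern, assuming each transform is a callable on a (time x neurons) array
class ComposeSketch:
    """Apply transforms in order, e.g. SelectIdxs -> Threshold -> ZScore."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, signal):
        for transform in self.transforms:
            signal = transform(signal)  # each step returns a filtered/rescaled array
        return signal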
Example #4
def load_metrics_csv_as_df(hparams,
                           lab,
                           expt,
                           metrics_list,
                           test=False,
                           version='best'):
    """Load metrics csv file and return as a pandas dataframe for easy plotting.

    Parameters
    ----------
    hparams : :obj:`dict`
        requires `sessions_csv`, `multisession`, `lab`, `expt`, `animal` and `session`
    lab : :obj:`str`
        for `get_lab_example`
    expt : :obj:`str`
        for `get_lab_example`
    metrics_list : :obj:`list`
        names of metrics to pull from csv; do not prepend with 'tr', 'val', or 'test'
    test : :obj:`bool`
        True to only return test values (computed once at end of training)
    version : :obj:`str`, :obj:`int`, or :obj:`NoneType`
        'best' to find the best model in the test-tube experiment, :obj:`None` to find the model
        with hyperparams defined in `hparams`, or an int to load a specific model version

    Returns
    -------
    :obj:`pandas.DataFrame` object

    """

    import os

    import pandas as pd

    from behavenet import get_lab_example  # import location assumed
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir
    from behavenet.fitting.utils import read_session_info_from_csv  # import location assumed

    # programmatically fill out other hparams options
    get_lab_example(hparams, lab, expt)
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # find metrics csv file
    if version == 'best':
        version = get_best_model_version(hparams['expt_dir'])[0]
    elif isinstance(version, int):
        pass  # use the specified version as-is
    else:
        _, version = experiment_exists(hparams, which_version=True)
    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)
    metric_file = os.path.join(version_dir, 'metrics.csv')
    metrics = pd.read_csv(metric_file)

    # collect data from csv file
    sess_ids = read_session_info_from_csv(
        os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = []
    for sess_id in sess_ids:
        sess_ids_strs.append(
            str('%s/%s' % (sess_id['animal'], sess_id['session'])))
    metrics_df = []
    for i, row in metrics.iterrows():
        dataset = 'all' if row['dataset'] == -1 else sess_ids_strs[
            row['dataset']]
        if test:
            test_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'test'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **test_dict, 'loss': metric,
                            'val': row['test_%s' % metric]
                        },
                        index=[0]))
        else:
            # make dict for val data
            val_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'val'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **val_dict, 'loss': metric,
                            'val': row['val_%s' % metric]
                        },
                        index=[0]))
            # NOTE: the commented lines below show an old version that returned a single dataframe
            # row containing all losses per epoch; the new way creates one row per loss, making it
            # easy to use with seaborn's FacetGrid object for multi-axis plotting
            # for metric in metrics_list:
            #     val_dict[metric] = row['val_%s' % metric]
            # metrics_df.append(pd.DataFrame(val_dict, index=[0]))
            # make dict for train data
            tr_dict = {
                'dataset': dataset,
                'epoch': row['epoch'],
                'dtype': 'train'
            }
            for metric in metrics_list:
                metrics_df.append(
                    pd.DataFrame(
                        {
                            **tr_dict, 'loss': metric,
                            'val': row['tr_%s' % metric]
                        },
                        index=[0]))
            # for metric in metrics_list:
            #     tr_dict[metric] = row['tr_%s' % metric]
            # metrics_df.append(pd.DataFrame(tr_dict, index=[0]))
    return pd.concat(metrics_df, sort=True)
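
# --- usage sketch (added illustration, not from the original source) ---
# the one-row-per-loss layout built above plugs directly into seaborn's FacetGrid for
# multi-axis plotting, as the NOTE in the function mentions; `hparams`, `lab`, `expt`,
# and the metric name 'loss' are assumed to be supplied by the caller
import matplotlib.pyplot as plt
import seaborn as sns

metrics_df = load_metrics_csv_as_df(hparams, lab, expt, metrics_list=['loss'])
g = sns.FacetGrid(metrics_df, col='dataset', hue='dtype')
g.map(plt.plot, 'epoch', 'val')
g.add_legend()
plt.show()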