def _decode_trials_to_frames(ae_model, latents_by_trial, max_frames, n_buffer):
    """Decode consecutive trials of latents into a single stack of frames.

    Trials are pushed through ``ae_model.decoding`` one at a time, starting with trial 0, until at
    least ``max_frames`` frames (blank separator frames included) have been collected. An
    :obj:`IndexError` is raised if ``latents_by_trial`` runs out of trials before that.

    Parameters
    ----------
    ae_model : autoencoder object
        must expose ``hparams`` (with 'n_input_channels', 'y_pixels', 'x_pixels') and ``decoding``
    latents_by_trial : :obj:`list`
        one array of latents per trial
    max_frames : :obj:`int`
        minimum number of frames to collect before stopping
    n_buffer : :obj:`int`
        number of blank frames appended after each decoded trial

    Returns
    -------
    :obj:`np.ndarray`
        array of shape (n_frames, n_channels * y_pix, x_pix)

    """
    n_channels = ae_model.hparams['n_input_channels']
    y_pix = ae_model.hparams['y_pixels']
    x_pix = ae_model.hparams['x_pixels']

    # blank frames separating consecutive trials
    zero_frames = np.zeros((n_buffer, n_channels * y_pix, x_pix))

    # start from an empty float array so the output shape/dtype is defined even if no trial is
    # decoded; chunks are collected in a list and concatenated once at the end
    chunks = [np.zeros((0, n_channels * y_pix, x_pix))]
    n_frames = 0
    i_trial = 0
    while n_frames < max_frames:
        recon = ae_model.decoding(
            torch.tensor(latents_by_trial[i_trial]).float(), None, None).cpu().detach().numpy()
        # lay the channels of each frame out along a single image axis
        recon = np.concatenate([recon[:, i] for i in range(recon.shape[1])], axis=1)
        chunks.append(recon)
        chunks.append(zero_frames)
        n_frames += recon.shape[0] + n_buffer
        i_trial += 1
    return np.concatenate(chunks, axis=0)


def real_vs_sampled_wrapper(
        output_type, hparams, save_file, sess_idx, dtype='test', conditional=True, max_frames=400,
        frame_rate=20, n_buffer=5, xtick_locs=None, frame_rate_beh=None, format='png'):
    """Produce movie with (AE) reconstructed video and sampled video.

    This is a high-level function that loads the model described in the hparams dictionary and
    produces the necessary state sequences/samples. The sampled video can be completely
    unconditional (states and latents are sampled) or conditioned on the most likely state
    sequence.

    Parameters
    ----------
    output_type : :obj:`str`
        'plot' | 'movie' | 'both'
    hparams : :obj:`dict`
        needs to contain enough information to specify an autoencoder
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make plot/video with; 'train' | 'val' | 'test'
    conditional : :obj:`bool`
        conditional vs unconditional samples; for creating reconstruction title
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    n_buffer : :obj:`int`
        number of blank frames between animated trials if more than one are needed to reach
        :obj:`max_frames`
    xtick_locs : :obj:`array-like`, optional
        tick locations in bin values for plot
    frame_rate_beh : :obj:`float`, optional
        behavioral video framerate; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle if :obj:`output_type='plot'` or :obj:`output_type='both'`, else
        nothing returned (movie is saved)

    Raises
    ------
    ValueError
        if :obj:`output_type` is not one of 'plot' | 'movie' | 'both'

    """
    # fail fast: reject an invalid request before any model/data loading takes place
    if output_type not in ('plot', 'movie', 'both'):
        raise ValueError('"%s" is an invalid output_type' % output_type)

    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # check input - cannot create sampled movies for arhmm-labels models (no mapping from labels to
    # frames)
    if hparams['model_class'].find('labels') > -1:
        if output_type == 'both' or output_type == 'movie':
            print('warning: cannot create video with "arhmm-labels" model; producing plots')
            output_type = 'plot'

    # load latents and states (observed and sampled); forward `dtype` so that conditional samples
    # are generated from the same trials that are plotted/animated below
    model_output = get_model_latents_states(
        hparams, '', sess_idx=sess_idx, return_samples=50, cond_sampling=conditional,
        dtype=dtype)

    if output_type == 'both' or output_type == 'movie':

        # load in AE decoder
        if hparams.get('ae_model_path', None) is not None:
            ae_model_dir = hparams['ae_model_path']
        else:
            hparams['session_dir'], sess_ids = get_session_dir(hparams)
            hparams['expt_dir'] = get_expt_dir(hparams)
            _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
            ae_model_dir = os.path.dirname(latents_file)
        ae_model_file = os.path.join(ae_model_dir, 'best_val_model.pt')
        # context manager guarantees the meta file handle is closed
        with open(os.path.join(ae_model_dir, 'meta_tags.pkl'), 'rb') as f:
            ae_arch = pickle.load(f)

        print('loading model from %s' % ae_model_file)
        ae_model = AE(ae_arch)
        ae_model.load_state_dict(
            torch.load(ae_model_file, map_location=lambda storage, loc: storage))
        ae_model.eval()

        # push observed latents through ae decoder
        ims_recon = _decode_trials_to_frames(
            ae_model, model_output['latents'][dtype], max_frames, n_buffer)

        # push sampled latents through ae decoder
        ims_recon_samp = _decode_trials_to_frames(
            ae_model, model_output['latents_gen'], max_frames, n_buffer)

        make_real_vs_sampled_movies(
            ims_recon, ims_recon_samp, conditional=conditional, save_file=save_file,
            frame_rate=frame_rate)

    fig = None  # stays None when only a movie is requested
    if output_type == 'both' or output_type == 'plot':

        i_trial = 0
        latents = model_output['latents'][dtype][i_trial][:max_frames]
        states = model_output['states'][dtype][i_trial][:max_frames]
        latents_samp = model_output['latents_gen'][i_trial][:max_frames]
        if not conditional:
            states_samp = model_output['states_gen'][i_trial][:max_frames]
        else:
            # conditional samples reuse the inferred states; nothing extra to plot
            states_samp = []

        fig = plot_real_vs_sampled(
            latents, latents_samp, states, states_samp, save_file=save_file,
            xtick_locs=xtick_locs, frame_rate=frame_rate_beh, format=format)

    return fig
def make_syllable_movies_wrapper(
        hparams, save_file, sess_idx=0, dtype='test', max_frames=400, frame_rate=10,
        min_threshold=0, n_buffer=5, n_pre_frames=3, n_rows=None, single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    High-level entry point: loads the video frames, the latents (or labels) and the arhmm
    described by the hparams dictionary, infers the discrete states, and hands everything to
    :obj:`make_syllable_movies`.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    save_file : :obj:`str`
        full save file (path and filename)
    sess_idx : :obj:`int`, optional
        session index into data generator
    dtype : :obj:`str`, optional
        types of trials to make video with; 'train' | 'val' | 'test'
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    min_threshold : :obj:`int`, optional
        minimum number of frames in a syllable run to be considered for movie
    n_buffer : :obj:`int`
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`
        choose only a single state for movie

    """
    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    # ---- video frames ----
    session_dir, sess_ids = get_session_dir(hparams)
    hparams['session_dir'] = session_dir
    hparams['expt_dir'] = get_expt_dir(hparams)
    hparams['load_videos'] = True
    hparams, signals, transforms, paths = get_data_generator_inputs(hparams, sess_ids)
    generator = ConcatSessionsGenerator(
        hparams['data_dir'], sess_ids,
        signals_list=[signals[sess_idx]],
        transforms_list=[transforms[sess_idx]],
        paths_list=[paths[sess_idx]],
        device='cpu', as_numpy=True, batch_load=False, rng_seed=hparams['rng_seed_data'])
    ims_orig = generator.datasets[sess_idx].data['images']
    del generator  # only the images are needed from here on; free up memory

    # ---- test tube version of the arhmm ----
    _, version = experiment_exists(hparams, which_version=True)
    print('producing syllable videos for arhmm %s' % version)

    # ---- latents (or labels standing in for latents) ----
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            latents = pickle.load(f)
    trial_idxs = latents['trials'][dtype]

    # ---- arhmm ----
    model_file = os.path.join(
        hparams['expt_dir'], 'version_%i' % version, 'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # ---- most likely discrete state sequence for every requested trial ----
    states = []
    for trial in trial_idxs:
        states.append(hmm.most_likely_states(latents['latents'][trial]))
    if not states:
        raise ValueError('No latents for dtype=%s' % dtype)

    # find runs of discrete states; state indices is a list, each entry of which is a np array
    # with shape (n_state_instances, 3), where the 3 values are:
    # chunk_idx, chunk_start_idx, chunk_end_idx
    # chunk_idx is in [0, n_chunks], and indexes trial_idxs
    state_indices = get_discrete_chunks(states, include_edges=True)
    K = len(state_indices)

    # keep, per state, only the instances whose run length exceeds the threshold
    over_threshold_instances = [[] for _ in range(K)]
    for i_state, instances in enumerate(state_indices):
        if instances.shape[0] == 0:
            continue  # state never visited; leave its entry as an empty list
        run_lengths = np.diff(instances[:, 1:3], axis=1)
        long_enough = run_lengths > min_threshold
        selected = instances[long_enough[:, 0]]
        np.random.shuffle(selected)  # shuffle instances
        over_threshold_instances[i_state] = selected

    make_syllable_movies(
        ims_orig=ims_orig,
        state_list=over_threshold_instances,
        trial_idxs=trial_idxs,
        save_file=save_file,
        max_frames=max_frames,
        frame_rate=frame_rate,
        n_buffer=n_buffer,
        n_pre_frames=n_pre_frames,
        n_rows=n_rows,
        single_syllable=single_syllable)
def plot_neural_reconstruction_traces_wrapper(
        hparams, save_file=None, trial=None, xtick_locs=None, frame_rate=None, format='png',
        **kwargs):
    """Plot ae latents and their neural reconstructions.

    High-level entry point: resolves the paths of the ae latents and of the decoder's predicted
    latents from the hparams dictionary, loads a single trial of both, and passes the traces to
    :obj:`plot_neural_reconstruction_traces`.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an ae latent decoder
    save_file : :obj:`str`
        full save file (path and filename)
    trial : :obj:`int`, optional
        if :obj:`NoneType`, use first test trial
    xtick_locs : :obj:`array-like`, optional
        tick locations in units of bins
    frame_rate : :obj:`float`, optional
        frame rate of behavioral video; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'
    kwargs
        forwarded unchanged to :obj:`plot_neural_reconstruction_traces`

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle of plot

    """
    import copy
    from behavenet.data.utils import get_transforms_paths
    from behavenet.data.data_generator import ConcatSessionsGenerator

    suffixes = ('experiment_name', 'model_class', 'model_type')

    # ae data: shallow copy of hparams with the generic model keys pointing at the ae
    hparams_ae = copy.copy(hparams)
    for suffix in suffixes:
        hparams_ae[suffix] = hparams['ae_' + suffix]
    ae_transform, ae_path = get_transforms_paths('ae_latents', hparams_ae, None)

    # ae predictions data: shallow copy with the neural-ae keys pointing at the decoder
    hparams_dec = copy.copy(hparams)
    for suffix in suffixes:
        hparams_dec['neural_ae_' + suffix] = hparams['decoder_' + suffix]
    ae_pred_transform, ae_pred_path = get_transforms_paths(
        'neural_ae_predictions', hparams_dec, None)

    data_generator = ConcatSessionsGenerator(
        hparams['data_dir'], [hparams],
        signals_list=[['ae_latents', 'ae_predictions']],
        transforms_list=[[ae_transform, ae_pred_transform]],
        paths_list=[[ae_path, ae_pred_path]],
        device='cpu', as_numpy=False, batch_load=True, rng_seed=0)
    dataset = data_generator.datasets[0]

    if trial is None:
        # choose first test trial
        trial = dataset.batch_idxs['test'][0]

    batch = dataset[trial]
    traces_ae = batch['ae_latents'].cpu().detach().numpy()
    traces_neural = batch['ae_predictions'].cpu().detach().numpy()

    # only plot valid segment of data: drop `n_max_lags` bins from both ends when lags were used
    n_max_lags = hparams.get('n_max_lags', 0)
    if n_max_lags > 0:
        traces_ae = traces_ae[n_max_lags:-n_max_lags]
        traces_neural = traces_neural[n_max_lags:-n_max_lags]

    return plot_neural_reconstruction_traces(
        traces_ae, traces_neural, save_file, xtick_locs, frame_rate, format, **kwargs)
def get_model_latents_states(
        hparams, version, sess_idx=0, return_samples=0, cond_sampling=False, dtype='test',
        dtypes=('train', 'val', 'test'), rng_seed=0):
    """Return arhmm defined in :obj:`hparams` with associated latents and states.

    Can also return sampled latents and states.

    Note that :obj:`hparams` is updated in place with 'session_dir' and 'expt_dir' entries, and
    that sampling seeds the global NumPy random number generator with :obj:`rng_seed`.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify an arhmm
    version : :obj:`str` or :obj:`int`
        test tube model version (can be 'best'); any value other than 'best' triggers a lookup
        of the version matching :obj:`hparams`
    sess_idx : :obj:`int`, optional
        session index into data generator
    return_samples : :obj:`int`, optional
        number of trials to sample from model; when :obj:`cond_sampling=True` any positive value
        simply turns sampling on and one sample is produced per :obj:`dtype` trial
    cond_sampling : :obj:`bool`, optional
        if :obj:`True` return samples conditioned on most likely state sequence; else return
        unconditioned samples
    dtype : :obj:`str`, optional
        trial type to use for conditional sampling; 'train' | 'val' | 'test'; must be one of
        :obj:`dtypes` when samples are requested
    dtypes : :obj:`array-like`, optional
        trial types for which to collect latents and states
    rng_seed : :obj:`int`, optional
        random number generator seed to control sampling

    Returns
    -------
    :obj:`dict`
        - 'model' (:obj:`ssm.HMM` object)
        - 'latents' (:obj:`dict`): latents from train, val and test trials
        - 'states' (:obj:`dict`): states from train, val and test trials
        - 'trial_idxs' (:obj:`dict`): trial indices from train, val and test trials
        - 'latents_gen' (:obj:`list`)
        - 'states_gen' (:obj:`list`): empty when :obj:`cond_sampling=True`

    Raises
    ------
    FileNotFoundError
        if no model version could be determined

    """
    from behavenet.data.utils import get_transforms_paths
    from behavenet.fitting.utils import experiment_exists
    from behavenet.fitting.utils import get_best_model_version
    from behavenet.fitting.utils import get_expt_dir
    from behavenet.fitting.utils import get_session_dir

    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # get version/model
    if version == 'best':
        version = get_best_model_version(
            hparams['expt_dir'], measure='val_loss', best_def='max')[0]
    else:
        _, version = experiment_exists(hparams, which_version=True)
    if version is None:
        raise FileNotFoundError(
            'Could not find the specified model version in %s' % hparams['expt_dir'])

    # load model
    model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version, 'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # load latents/labels
    if hparams['model_class'].find('labels') > -1:
        from behavenet.data.utils import load_labels_like_latents
        all_latents = load_labels_like_latents(hparams, sess_ids, sess_idx)
    else:
        _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
        with open(latents_file, 'rb') as f:
            all_latents = pickle.load(f)

    # collect inferred latents/states
    trial_idxs = {}
    latents = {}
    states = {}
    for data_type in dtypes:
        trial_idxs[data_type] = all_latents['trials'][data_type]
        latents[data_type] = [
            all_latents['latents'][i_trial] for i_trial in trial_idxs[data_type]]
        states[data_type] = [hmm.most_likely_states(x) for x in latents[data_type]]

    # collect sampled latents/states
    latents_gen = []
    states_gen = []
    if return_samples > 0:
        np.random.seed(rng_seed)
        if cond_sampling:
            # one sampled latent trajectory per trial, following that trial's inferred states
            n_latents = latents[dtype][0].shape[1]
            latents_gen = [
                np.zeros((len(state_seg), n_latents)) for state_seg in states[dtype]]
            for i_seg, state_seg in enumerate(states[dtype]):
                for i_t, state in enumerate(state_seg):
                    if i_t >= 1:
                        # condition on everything sampled so far in this trial
                        history = latents_gen[i_seg][:i_t]
                    else:
                        # no sampled history yet: seed with the first observed latent
                        history = latents[dtype][i_seg][0].reshape((1, n_latents))
                    latents_gen[i_seg][i_t] = hmm.observations.sample_x(
                        state, history, input=np.zeros(0))
        else:
            # sample `offset` extra time points and discard them from the start of each sample
            offset = 200
            for _ in range(return_samples):
                these_states_gen, these_latents_gen = hmm.sample(
                    latents[dtype][0].shape[0] + offset)
                latents_gen.append(these_latents_gen[offset:])
                states_gen.append(these_states_gen[offset:])

    return_dict = {
        'model': hmm,
        'latents': latents,
        'states': states,
        'trial_idxs': trial_idxs,
        'latents_gen': latents_gen,
        'states_gen': states_gen,
    }
    return return_dict
def test_get_transforms_paths():
    """Check the (transform, path) pairs returned by `utils.get_transforms_paths`.

    A single hparams dict is mutated step by step; each section below sets the keys relevant to
    one signal type and compares the returned path/transform against the expected values.
    """

    hparams = {
        'data_dir': 'ddir', 'results_dir': 'rdir', 'lab': 'lab', 'expt': 'expt',
        'animal': 'animal', 'session': 'session'}

    session_keys = ('lab', 'expt', 'animal', 'session')
    session_dir = os.path.join(hparams['data_dir'], *[hparams[k] for k in session_keys])
    hdf5_path = os.path.join(session_dir, 'data.hdf5')
    sess_id_str = '%s_%s_%s_%s_' % tuple(hparams[k] for k in session_keys)

    def query(signal):
        # every call in this test uses the same hparams dict and the same keyword arguments
        return utils.get_transforms_paths(signal, hparams, sess_id=None, check_splits=False)

    # ------------------------
    # neural data
    # ------------------------
    # spikes, no thresholding
    hparams.update({'neural_type': 'spikes', 'neural_thresh': 0})
    transform, path = query('neural')
    assert path == hdf5_path
    assert transform is None

    # spikes, thresholding
    hparams.update({'neural_type': 'spikes', 'neural_thresh': 1, 'neural_bin_size': 1})
    transform, path = query('neural')
    assert path == hdf5_path
    assert 'Threshold' in repr(transform)

    # calcium, no zscoring
    hparams.update({'neural_type': 'ca', 'model_type': 'ae-neural'})
    transform, path = query('neural')
    assert path == hdf5_path
    assert transform is None

    # calcium, zscoring
    hparams.update({'neural_type': 'ca', 'model_type': 'neural-ae'})
    transform, path = query('neural')
    assert path == hdf5_path
    assert 'ZScore' in repr(transform)

    # raise exception for incorrect neural type
    hparams['neural_type'] = 'wf'
    with pytest.raises(ValueError):
        query('neural')

    # TODO: test subsampling methods

    # ------------------------
    # ae latents
    # ------------------------
    hparams.update({
        'session_dir': session_dir,
        'ae_model_class': 'ae',
        'ae_model_type': 'conv',
        'n_ae_latents': 8,
        'ae_experiment_name': 'tt_expt_ae',
        'ae_version': 0,
    })
    ae_path = os.path.join(
        session_dir, hparams['ae_model_class'], hparams['ae_model_type'],
        '%02i_latents' % hparams['n_ae_latents'], hparams['ae_experiment_name'])
    ae_latents_file = os.path.join(
        ae_path, 'version_%i' % hparams['ae_version'], '%slatents.pkl' % sess_id_str)

    # user-defined latent path
    hparams['ae_latents_file'] = 'path/to/latents'
    transform, path = query('ae_latents')
    assert path == hparams['ae_latents_file']
    assert transform is None
    del hparams['ae_latents_file']

    # build pathname from hparams
    transform, path = query('ae_latents')
    assert path == ae_latents_file
    assert transform is None

    # get correct transform
    transform, path = query('ae_latents_me')
    assert path == ae_latents_file
    assert 'MotionEnergy' in repr(transform)

    # TODO: use get_best_model_version()

    # ------------------------
    # arhmm states
    # ------------------------
    hparams.update({
        'n_ae_latents': 8,
        'n_arhmm_states': 2,
        'transitions': 'stationary',
        'noise_type': 'gaussian',
        'arhmm_experiment_name': 'tt_expt_arhmm',
        'arhmm_version': 1,
    })
    arhmm_path = os.path.join(
        session_dir, 'arhmm', '%02i_latents' % hparams['n_ae_latents'],
        '%02i_states' % hparams['n_arhmm_states'], hparams['transitions'],
        hparams['noise_type'], hparams['arhmm_experiment_name'])
    arhmm_states_file = os.path.join(
        arhmm_path, 'version_%i' % hparams['arhmm_version'], '%sstates.pkl' % sess_id_str)

    # user-defined state path
    hparams['arhmm_states_file'] = 'path/to/states'
    transform, path = query('arhmm_states')
    assert path == hparams['arhmm_states_file']
    assert transform is None
    del hparams['arhmm_states_file']

    # build pathname from hparams
    transform, path = query('arhmm_states')
    assert path == arhmm_states_file
    assert transform is None

    # include shuffle transform
    hparams['shuffle_rng_seed'] = 0
    transform, path = query('arhmm_states')
    assert path == arhmm_states_file
    assert 'BlockShuffle' in repr(transform)

    # TODO: use get_best_model_version()

    # ------------------------
    # neural ae predictions
    # ------------------------
    hparams.update({
        'n_ae_latents': 8,
        'neural_ae_model_type': 'linear',
        'neural_ae_experiment_name': 'tt_expt_ae_decoder',
        'neural_ae_version': 2,
    })
    ae_pred_path = os.path.join(
        session_dir, 'neural-ae', '%02i_latents' % hparams['n_ae_latents'],
        hparams['neural_ae_model_type'], 'all', hparams['neural_ae_experiment_name'])

    # user-defined predictions path
    hparams['ae_predictions_file'] = 'path/to/predictions'
    transform, path = query('neural_ae_predictions')
    assert path == hparams['ae_predictions_file']
    assert transform is None
    del hparams['ae_predictions_file']

    # build pathname from hparams
    transform, path = query('neural_ae_predictions')
    assert path == os.path.join(
        ae_pred_path, 'version_%i' % hparams['neural_ae_version'],
        '%spredictions.pkl' % sess_id_str)
    assert transform is None

    # TODO: use get_best_model_version()

    # ------------------------
    # neural arhmm predictions
    # ------------------------
    hparams.update({
        'n_ae_latents': 8,
        'n_arhmm_states': 10,
        'transitions': 'stationary',
        'noise_type': 'studentst',
        'neural_arhmm_model_type': 'linear',
        'neural_arhmm_experiment_name': 'tt_expt_ae_decoder',
        'neural_arhmm_version': 3,
    })
    arhmm_pred_path = os.path.join(
        session_dir, 'neural-arhmm', '%02i_latents' % hparams['n_ae_latents'],
        '%02i_states' % hparams['n_arhmm_states'], hparams['transitions'],
        hparams['neural_arhmm_model_type'], 'all', hparams['neural_arhmm_experiment_name'])

    # user-defined predictions path
    hparams['arhmm_predictions_file'] = 'path/to/predictions'
    transform, path = query('neural_arhmm_predictions')
    assert path == hparams['arhmm_predictions_file']
    assert transform is None
    del hparams['arhmm_predictions_file']

    # build pathname from hparams
    transform, path = query('neural_arhmm_predictions')
    assert path == os.path.join(
        arhmm_pred_path, 'version_%i' % hparams['neural_arhmm_version'],
        '%spredictions.pkl' % sess_id_str)
    assert transform is None

    # TODO: use get_best_model_version()

    # ------------------------
    # other
    # ------------------------
    with pytest.raises(ValueError):
        query('invalid')