def main(hparams, *args):
    """Fit an autoencoder-family model (AE/VAE/beta-TCVAE/PS-VAE/etc.) as a test-tube experiment.

    Builds the data generator, constructs the model class selected by
    ``hparams['model_class']``, trains it with :func:`fit`, exports hparams before and
    after training, and optionally exports training-curve plots.

    Parameters
    ----------
    hparams : dict or argparse.Namespace
        Experiment hyperparameters; a Namespace is converted to a dict via ``vars()``.
        Keys read here include 'model_type', 'model_class', 'rng_seed_model', 'device',
        'n_parallel_gpus', 'export_train_plots', and, for conv models,
        'architecture_params', 'n_ae_latents', 'max_latents'.
    *args
        Unused here; presumably accepted for compatibility with a hyperparameter-search
        callback signature — TODO confirm against caller.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    if hparams['model_type'] == 'conv':
        # blend outer hparams with architecture hparams; note the outer hparams take
        # precedence on key collisions (they are unpacked last)
        hparams = {**hparams['architecture_params'], **hparams}

    # print hparams to console
    _print_hparams(hparams)

    # hard stop before any expensive work if the requested latent dim cannot be built
    if hparams['model_type'] == 'conv' and hparams['n_ae_latents'] > hparams['max_latents']:
        raise ValueError(
            'Number of latents higher than max latents, architecture will not work')

    # Start at random times (so test tube creates separate folders)
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(3))

    # create test-tube experiment; returns None for hparams if this experiment
    # already exists on disk
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    def set_n_labels(data_generator, hparams):
        # infer label dimensionality from one validation batch and record it in hparams;
        # needed by the label-conditioned model classes below
        data, _ = data_generator.next_batch('val')
        sh = data['labels'].shape
        hparams['n_labels'] = sh[2]  # [1, n_t, n_labels]

    print('constructing model...', end='')
    # seed model construction and stash the torch RNG state for reproducibility
    torch.manual_seed(hparams['rng_seed_model'])
    torch_rng_seed = torch.get_rng_state()
    hparams['model_build_rng_seed'] = torch_rng_seed
    hparams['n_datasets'] = len(sess_ids)

    # dispatch on model class; imports are deferred so only the selected model is loaded
    if hparams['model_class'] == 'ae':
        from behavenet.models import AE as Model
    elif hparams['model_class'] == 'vae':
        from behavenet.models import VAE as Model
    elif hparams['model_class'] == 'beta-tcvae':
        from behavenet.models import BetaTCVAE as Model
    elif hparams['model_class'] == 'ps-vae':
        from behavenet.models import PSVAE as Model
        set_n_labels(data_generator, hparams)
    elif hparams['model_class'] == 'msps-vae':
        from behavenet.models import MSPSVAE as Model
        set_n_labels(data_generator, hparams)
    elif hparams['model_class'] == 'cond-vae':
        from behavenet.models import ConditionalVAE as Model
        set_n_labels(data_generator, hparams)
    elif hparams['model_class'] == 'cond-ae':
        from behavenet.models import ConditionalAE as Model
        set_n_labels(data_generator, hparams)
    elif hparams['model_class'] == 'cond-ae-msp':
        from behavenet.models import AEMSP as Model
        set_n_labels(data_generator, hparams)
    else:
        raise NotImplementedError(
            'The model class "%s" is not currently implemented' % hparams['model_class'])

    model = Model(hparams)
    model.to(hparams['device'])

    # load pretrained weights if specified
    model = load_pretrained_ae(model, hparams)

    # Parallelize over gpus if desired
    if hparams['n_parallel_gpus'] > 1:
        from behavenet.models import CustomDataParallel
        model = CustomDataParallel(model)

    model.version = exp.version
    # record RNG state at the start of training as well
    torch_rng_seed = torch.get_rng_state()
    hparams['training_rng_seed'] = torch_rng_seed

    # save out hparams as csv and dict; 'training_completed' is flipped to True only
    # after fit() returns, so interrupted runs are detectable
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ###################
    # ### TRAIN MODEL ###
    # ###################

    print(model)
    fit(hparams, model, data_generator, exp, method='ae')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)

    # export training plots
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(
            hparams['expt_dir'], 'version_%i' % hparams['version'])
        if hparams['model_class'] == 'msps-vae':
            # msps-vae has its own multi-curve plotting helper parameterized by the
            # ps-vae loss weights
            from behavenet.plotting.cond_ae_utils import plot_mspsvae_training_curves
            save_file = os.path.join(version_dir, 'loss_training')
            plot_mspsvae_training_curves(
                hparams, alpha=hparams['ps_vae.alpha'], beta=hparams['ps_vae.beta'],
                delta=hparams['ps_vae.delta'], rng_seed_model=hparams['rng_seed_model'],
                n_latents=hparams['n_ae_latents'] - hparams['n_background'] - hparams['n_labels'],
                n_background=hparams['n_background'], n_labels=hparams['n_labels'],
                dtype='train', save_file=save_file, format='png', version_dir=version_dir)
            save_file = os.path.join(version_dir, 'loss_validation')
            plot_mspsvae_training_curves(
                hparams, alpha=hparams['ps_vae.alpha'], beta=hparams['ps_vae.beta'],
                delta=hparams['ps_vae.delta'], rng_seed_model=hparams['rng_seed_model'],
                n_latents=hparams['n_ae_latents'] - hparams['n_background'] - hparams['n_labels'],
                n_background=hparams['n_background'], n_labels=hparams['n_labels'],
                dtype='val', save_file=save_file, format='png', version_dir=version_dir)
        else:
            save_file = os.path.join(version_dir, 'loss_training')
            export_train_plots(hparams, 'train', save_file=save_file)
            save_file = os.path.join(version_dir, 'loss_validation')
            export_train_plots(hparams, 'val', save_file=save_file)
        print('done')
def main(hparams, *args):
    """Fit a ConvDecoder model as a test-tube experiment.

    Builds the data generator, constructs a :class:`ConvDecoder`, trains it with
    :func:`fit` (method='conv-decoder'), exports hparams before/after training, and
    optionally exports training-curve plots.

    Parameters
    ----------
    hparams : dict or argparse.Namespace
        Experiment hyperparameters; a Namespace is converted to a dict via ``vars()``.
        Keys read here include 'model_type', 'rng_seed_model', 'device',
        'export_train_plots', and, for conv models, 'architecture_params'.
    *args
        Unused here; presumably accepted for compatibility with a hyperparameter-search
        callback signature — TODO confirm against caller.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    if hparams['model_type'] == 'conv':
        # blend outer hparams with architecture hparams
        # NOTE(review): here architecture_params override the outer hparams (unpacked
        # last) — the opposite precedence of the AE main in this file; confirm this
        # asymmetry is intentional
        hparams = {**hparams, **hparams['architecture_params']}

    # print hparams to console
    _print_hparams(hparams)

    # Start at random times (so test tube creates separate folders)
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(1))

    # create test-tube experiment; returns None for hparams if this experiment
    # already exists on disk
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    print('constructing model...', end='')
    # seed model construction and stash the torch RNG state for reproducibility
    # NOTE(review): keys use 'rnd_seed' here while the sibling mains use 'rng_seed';
    # kept as-is because downstream consumers may read these exact exported keys
    torch.manual_seed(hparams['rng_seed_model'])
    torch_rnd_seed = torch.get_rng_state()
    hparams['model_build_rnd_seed'] = torch_rnd_seed
    hparams['n_datasets'] = len(sess_ids)

    # infer label dimensionality from one training batch
    data, _ = data_generator.next_batch('train')
    sh = data['labels'].shape
    hparams['n_labels'] = sh[2]  # [1, n_t, n_labels]

    model = ConvDecoder(hparams)
    model.to(hparams['device'])

    # Load pretrained weights if specified
    # model = load_pretrained_ae(model, hparams)

    model.version = exp.version
    torch_rnd_seed = torch.get_rng_state()
    hparams['training_rnd_seed'] = torch_rnd_seed

    # save out hparams as csv and dict; 'training_completed' is flipped to True only
    # after fit() returns, so interrupted runs are detectable
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ###################
    # ### TRAIN MODEL ###
    # ###################

    fit(hparams, model, data_generator, exp, method='conv-decoder')

    # export training plots
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(
            hparams['expt_dir'], 'version_%i' % hparams['version'])
        save_file = os.path.join(version_dir, 'loss_training')
        export_train_plots(hparams, 'train', save_file=save_file)
        save_file = os.path.join(version_dir, 'loss_validation')
        export_train_plots(hparams, 'val', save_file=save_file)
        print('done')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)
def main(hparams, *args):
    """Fit a neural decoding/encoding model (Decoder) as a test-tube experiment.

    Infers input/output sizes from the data generator based on
    ``hparams['model_class']``, records the paths of the upstream AE/ARHMM models,
    constructs a :class:`Decoder`, trains it with :func:`fit` (method='nll'), and
    exports hparams before/after training.

    Parameters
    ----------
    hparams : dict or argparse.Namespace
        Experiment hyperparameters; a Namespace is converted to a dict via ``vars()``.
        Keys read here include 'model_class', 'input_signal', 'output_signal',
        'rng_seed_model', 'device', and size keys such as 'n_arhmm_states',
        'n_ae_latents', 'n_labels'.
    *args
        Unused here; presumably accepted for compatibility with a hyperparameter-search
        callback signature — TODO confirm against caller.

    Raises
    ------
    ValueError
        If ``hparams['model_class']`` is not one of the supported classes.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    # print hparams to console
    _print_hparams(hparams)

    # Start at random times (so test tube creates separate folders)
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(1))

    # create test-tube experiment; returns None for hparams if this experiment
    # already exists on disk
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # use one example trial from the first dataset to infer signal dimensionalities
    ex_trial = data_generator.datasets[0].batch_idxs['train'][0]
    i_sig = hparams['input_signal']
    o_sig = hparams['output_signal']

    def _signal_dim(sig):
        # dimensionality (second axis) of a signal on the example trial
        return data_generator.datasets[0][ex_trial][sig].shape[1]

    model_class = hparams['model_class']
    if model_class == 'neural-arhmm':
        hparams['input_size'] = _signal_dim(i_sig)
        hparams['output_size'] = hparams['n_arhmm_states']
    elif model_class == 'arhmm-neural':
        hparams['input_size'] = hparams['n_arhmm_states']
        hparams['output_size'] = _signal_dim(o_sig)
    elif model_class in ('neural-ae', 'neural-ae-me'):
        hparams['input_size'] = _signal_dim(i_sig)
        hparams['output_size'] = hparams['n_ae_latents']
    elif model_class == 'ae-neural':
        hparams['input_size'] = hparams['n_ae_latents']
        hparams['output_size'] = _signal_dim(o_sig)
    elif model_class == 'neural-labels':
        hparams['input_size'] = _signal_dim(i_sig)
        hparams['output_size'] = hparams['n_labels']
    elif model_class == 'labels-neural':
        hparams['input_size'] = hparams['n_labels']
        hparams['output_size'] = _signal_dim(o_sig)
    else:
        raise ValueError('%s is an invalid model class' % model_class)

    if model_class in ('neural-ae', 'neural-ae-me', 'ae-neural'):
        # BUG FIX: the original condition compared against 'neural-ae' twice, so the
        # 'neural-ae-me' class (accepted above) never had its AE paths recorded
        hparams['ae_model_path'] = os.path.join(
            os.path.dirname(data_generator.datasets[0].paths['ae_latents']))
        hparams['ae_model_latents_file'] = data_generator.datasets[0].paths[
            'ae_latents']
    elif model_class in ('neural-arhmm', 'arhmm-neural'):
        hparams['arhmm_model_path'] = os.path.dirname(
            data_generator.datasets[0].paths['arhmm_states'])
        hparams['arhmm_model_states_file'] = data_generator.datasets[0].paths[
            'arhmm_states']
        # Store which AE was used for the ARHMM; context manager ensures the
        # meta-tags file handle is closed (original leaked the handle)
        with open(os.path.join(
                hparams['arhmm_model_path'], 'meta_tags.pkl'), 'rb') as f:
            tags = pickle.load(f)
        hparams['ae_model_latents_file'] = tags['ae_model_latents_file']

    # ####################
    # ### CREATE MODEL ###
    # ####################

    print('constructing model...', end='')
    # seed model construction and stash the torch RNG state for reproducibility
    torch.manual_seed(hparams['rng_seed_model'])
    torch_rng_seed = torch.get_rng_state()
    hparams['model_build_rng_seed'] = torch_rng_seed

    model = Decoder(hparams)
    model.to(hparams['device'])
    model.version = exp.version
    torch_rng_seed = torch.get_rng_state()
    hparams['training_rng_seed'] = torch_rng_seed

    # save out hparams as csv and dict for easy reloading; 'training_completed' is
    # flipped to True only after fit() returns, so interrupted runs are detectable
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ####################
    # ### TRAIN MODEL ###
    # ####################

    fit(hparams, model, data_generator, exp, method='nll')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)
def main(hparams):
    """Fit an (AR)HMM to autoencoder latents (or labels) as a test-tube experiment.

    Collects latent/label arrays from the data generator, constructs an ``ssm.HMM``
    with the observation/transition models selected by hparams, runs EM one iteration
    per epoch while logging train/val losses, relabels states by usage, pickles the
    model, and optionally exports states and training plots.

    Parameters
    ----------
    hparams : dict or argparse.Namespace
        Experiment hyperparameters; a Namespace is converted to a dict via ``vars()``.
        Keys read here include 'transitions', 'kappa', 'model_class', 'noise_type',
        'n_arhmm_lags', 'n_arhmm_states', 'n_iters', 'rng_seed_model',
        'export_states', 'export_train_plots'.

    Raises
    ------
    ValueError
        If model_class/lag settings disagree, or noise_type/transitions is invalid.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    # sanity checks: sticky transitions require kappa>0 and vice versa
    if hparams['transitions'] == 'sticky' and hparams['kappa'] == 0:
        print('Cannot fit sticky transitions with kappa=0! Aborting fit')
        return
    if hparams['transitions'] != 'sticky' and hparams['kappa'] > 0:
        print('Cannot fit %s transitions with kappa>0! Aborting fit' %
              hparams['transitions'])
        return

    # print hparams to console
    _print_hparams(hparams)

    # start at random times (so test tube creates separate folders)
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(1))

    # create test-tube experiment; returns None for hparams if this experiment
    # already exists on disk
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    # get all latents in list
    n_datasets = len(data_generator)
    print('collecting observations from data generator...', end='')
    # '-labels' model classes are fit on labels rather than AE latents
    data_key = 'ae_latents'
    if hparams['model_class'].find('labels') > -1:
        data_key = 'labels'
    # pooled arrays across all sessions, keyed by data split
    latents, trial_idxs = get_latent_arrays_by_dtype(
        data_generator, sess_idxs=list(range(n_datasets)), data_key=data_key)
    obs_dim = latents['train'][0].shape[1]
    hparams['total_train_length'] = np.sum(
        [z.shape[0] for z in latents['train']])
    # get separated by dataset as well (for per-session metrics below)
    latents_sess = {d: None for d in range(n_datasets)}
    trial_idxs_sess = {d: None for d in range(n_datasets)}
    for d in range(n_datasets):
        latents_sess[d], trial_idxs_sess[d] = get_latent_arrays_by_dtype(
            data_generator, sess_idxs=d, data_key=data_key)
    print('done')

    if hparams['model_class'] == 'arhmm' or hparams['model_class'] == 'hmm':
        # record which AE produced the latents being modeled
        hparams['ae_model_path'] = os.path.join(
            os.path.dirname(data_generator.datasets[0].paths['ae_latents']))
        hparams['ae_model_latents_file'] = data_generator.datasets[0].paths[
            'ae_latents']

    # model_class prefix must agree with the lag setting
    if hparams['n_arhmm_lags'] > 0:
        if hparams['model_class'][:5] != 'arhmm':  # 'arhmm' or 'arhmm-labels'
            raise ValueError(
                'Must specify model_class as arhmm when using AR lags')
    else:
        if hparams['model_class'][:3] != 'hmm':  # 'hmm' or 'hmm-labels'
            raise ValueError(
                'Must specify model_class as hmm when using 0 AR lags')

    # determine observation model (ssm observation-type string)
    if hparams['noise_type'] == 'gaussian':
        if hparams['n_arhmm_lags'] > 0:
            obs_type = 'ar'
        else:
            obs_type = 'gaussian'
    elif hparams['noise_type'] == 'studentst':
        if hparams['n_arhmm_lags'] > 0:
            obs_type = 'robust_ar'
        else:
            obs_type = 'studentst'
    elif hparams['noise_type'] == 'diagonal_gaussian':
        if hparams['n_arhmm_lags'] > 0:
            obs_type = 'diagonal_ar'
        else:
            obs_type = 'diagonal_gaussian'
    elif hparams['noise_type'] == 'diagonal_studentst':
        if hparams['n_arhmm_lags'] > 0:
            obs_type = 'diagonal_robust_ar'
        else:
            obs_type = 'diagonal_studentst'
    else:
        raise ValueError('%s is not a valid noise type' % hparams['noise_type'])
    if hparams['n_arhmm_lags'] > 0:
        obs_kwargs = {'lags': hparams['n_arhmm_lags']}
        obs_init_kwargs = {'localize': True}
    else:
        obs_kwargs = None
        obs_init_kwargs = {}

    # determine transition model (ssm transitions string + kwargs)
    if hparams['transitions'] == 'stationary' or hparams[
            'transitions'] == 'standard':
        transitions = 'stationary'
        transition_kwargs = None
    elif hparams['transitions'] == 'sticky':
        transitions = 'sticky'
        transition_kwargs = {'kappa': hparams['kappa']}
    elif hparams['transitions'] == 'recurrent':
        transitions = 'recurrent'
        transition_kwargs = None
    elif hparams['transitions'] == 'recurrent_only':
        transitions = 'recurrent_only'
        transition_kwargs = None
    else:
        raise ValueError('%s is not a valid transition type' %
                         hparams['transitions'])

    print('constructing model...', end='')
    # seed numpy before ssm model construction for reproducibility
    np.random.seed(hparams['rng_seed_model'])
    hmm = ssm.HMM(
        hparams['n_arhmm_states'], obs_dim, observations=obs_type,
        observation_kwargs=obs_kwargs, transitions=transitions,
        transition_kwargs=transition_kwargs)
    hmm.initialize(latents['train'])
    hmm.observations.initialize(latents['train'], **obs_init_kwargs)

    # save out hparams as csv and dict; 'training_completed' is flipped to True only
    # after fitting finishes, so interrupted runs are detectable
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    hmm.hparams = hparams
    print('done')

    # ####################
    # ### TRAIN MODEL ###
    # ####################

    # TODO: move fitting into own function
    # precompute normalizers (total number of scalar datapoints per split, pooled
    # and per-session) so logged losses are per-datapoint
    n_datapoints = {}
    n_datapoints_sess = {}
    # NOTE(review): set literal, so iteration order is arbitrary — harmless here
    # since all three splits get processed
    for dtype in {'train', 'val', 'test'}:
        n_datapoints[dtype] = np.vstack(latents[dtype]).size
        n_datapoints_sess[dtype] = {}
        for d in range(n_datasets):
            n_datapoints_sess[dtype][d] = np.vstack(
                latents_sess[d][dtype]).size

    val_ll_prev = np.inf
    # early-stopping tolerance on relative val-loss change; default 0 disables it
    tolerance = hparams.get('arhmm_es_tol', 0)
    # hmm.fit(
    #     latents['train'], method='em', num_iters=hparams['n_iters'], initialize=False,
    #     tolerance=tolerance)
    # epoch = hparams['n_iters']
    for epoch in range(hparams['n_iters'] + 1):
        # Note: the 0th epoch has no training (randomly initialized model is evaluated) so we cycle
        # through `n_iters` training epochs
        print('epoch %03i/%03i' % (epoch, hparams['n_iters']))
        if epoch > 0:
            # one EM iteration per epoch so metrics can be logged each epoch
            hmm.fit(latents['train'], method='em', num_iters=1, initialize=False)
        # export aggregated metrics on train/val data; losses are negative
        # log-likelihood per datapoint; dataset=-1 marks the pooled metric
        tr_ll = -hmm.log_likelihood(latents['train']) / n_datapoints['train']
        val_ll = -hmm.log_likelihood(latents['val']) / n_datapoints['val']
        exp.log({
            'epoch': epoch, 'dataset': -1, 'tr_loss': tr_ll,
            'val_loss': val_ll, 'trial': -1})
        # export individual session metrics on train/val data
        for d in range(data_generator.n_datasets):
            tr_ll = -hmm.log_likelihood(
                latents_sess[d]['train']) / n_datapoints_sess['train'][d]
            val_ll = -hmm.log_likelihood(
                latents_sess[d]['val']) / n_datapoints_sess['val'][d]
            exp.log({
                'epoch': epoch, 'dataset': d, 'tr_loss': tr_ll,
                'val_loss': val_ll, 'trial': -1})
        # check for convergence (only after a 10-epoch burn-in)
        if epoch > 10 and np.abs((val_ll - val_ll_prev) / val_ll) < tolerance:
            print(
                'relative change less than tolerance=%1.2f; training terminating!'
                % tolerance)
            break
        val_ll_prev = val_ll

    # export individual session metrics on test data; `epoch` keeps its value from
    # the (possibly early-terminated) loop above
    for d in range(n_datasets):
        for i, b in enumerate(trial_idxs_sess[d]['test']):
            n = latents_sess[d]['test'][i].size
            test_ll = -hmm.log_likelihood(latents_sess[d]['test'][i]) / n
            exp.log({
                'epoch': epoch, 'dataset': d, 'test_loss': test_ll,
                'trial': b})
    exp.save()

    # reconfigure model/states by usage (most-used state becomes state 0)
    zs = [hmm.most_likely_states(x) for x in latents['train']]
    usage = np.bincount(np.concatenate(zs), minlength=hmm.K)
    perm = np.argsort(usage)[::-1]
    hmm.permute(perm)

    # save model (pickled ssm.HMM, despite the .pt extension)
    filepath = os.path.join(
        hparams['expt_dir'], 'version_%i' % exp.version, 'best_val_model.pt')
    with open(filepath, 'wb') as f:
        pickle.dump(hmm, f)

    # ######################
    # ### EVALUATE ARHMM ###
    # ######################

    # export states
    if hparams['export_states']:
        export_states(hparams, data_generator, hmm)

    # export training plots
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(
            hparams['expt_dir'], 'version_%i' % hparams['version'])
        save_file = os.path.join(version_dir, 'loss_training')
        export_train_plots(
            hparams, 'train', loss_type='ll', save_file=save_file)
        save_file = os.path.join(version_dir, 'loss_validation')
        export_train_plots(hparams, 'val', loss_type='ll', save_file=save_file)
        print('done')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)