def train_cleitcs(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader, metric_name, seed, **kwargs):
    """Train a CLEIT-cs regressor and return it reloaded from its best validation checkpoint.

    A freshly initialised AE encoder plus a MoMLP head form ``target_regressor``. A second
    copy of the encoder is loaded from ``./model_save/ae/ft_encoder_{seed}.pt`` and handed to
    ``cleit_train_step`` as ``reference_encoder``, together with a trainable ``transmitter``
    MLP (latent_dim -> latent_dim). The regressor and the transmitter are optimised jointly
    with AdamW.

    :param s_dataloaders: source-domain training loader; it is iterated directly as one
        loader, despite the plural name
    :param t_dataloaders: target-domain training loader; one batch is drawn from it for
        every source batch
    :param val_dataloader: loader whose ``metric_name`` history drives checkpointing and
        early stopping through ``model_save_check`` (tolerance 50 epochs)
    :param test_dataloader: loader evaluated every epoch, and once more on the reloaded best
        model with its outputs written to ``kwargs['model_save_folder']``
    :param metric_name: key of the validation metric passed to ``model_save_check``
    :param seed: used in the reference-encoder and checkpoint file names
    :param kwargs: hyper-parameters. Reads input_dim, latent_dim, encoder_hidden_dims, dop,
        device, output_dim, regressor_hidden_dims, lr, train_num_epochs, alpha and
        model_save_folder.
    :return: ``(target_regressor, (train_history, source-train eval history,
        target-train eval history, validation eval history, test eval history))``
    """
    device = kwargs['device']
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders
    checkpoint_path = os.path.join(kwargs['model_save_folder'], f'cleitcs_regressor_{seed}.pt')

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(device)

    # Reference encoder: same architecture, weights taken from a previously fine-tuned encoder.
    # map_location lets a checkpoint written on another device (e.g. GPU) load here.
    aux_ae = deepcopy(autoencoder)
    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save/ae', f'ft_encoder_{seed}.pt'), map_location=device))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # Transmitter maps latent codes to latent codes (single hidden layer of width latent_dim).
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(device)

    encoder = autoencoder.encoder
    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(device)
    target_regressor = EncoderDecoder(encoder=encoder, decoder=target_decoder).to(device)

    train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    model_optimizer = torch.optim.AdamW(chain(target_regressor.parameters(), transmitter.parameters()),
                                        lr=kwargs['lr'])

    # Keep one target iterator alive and restart it only when exhausted. Calling
    # next(iter(loader)) every step would always yield the first batch of an unshuffled
    # loader and would pay the full iterator setup cost on every step.
    t_iter = iter(t_train_dataloader)
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'CLEIT-cs training epoch {epoch}')
        for s_batch in s_train_dataloader:
            try:
                t_batch = next(t_iter)
            except StopIteration:
                t_iter = iter(t_train_dataloader)
                t_batch = next(t_iter)
            train_history = cleit_train_step(model=target_regressor,
                                             transmitter=transmitter,
                                             reference_encoder=reference_encoder,
                                             s_batch=s_batch,
                                             t_batch=t_batch,
                                             device=device,
                                             optimizer=model_optimizer,
                                             alpha=kwargs['alpha'],
                                             history=train_history)

        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=s_train_dataloader, device=device,
            history=s_target_regression_eval_train_history)
        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=t_train_dataloader, device=device,
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=val_dataloader, device=device,
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=test_dataloader, device=device,
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        # Always checkpoint the first epoch. Without it, the reload below could fail, or could
        # silently pick up a stale file from an earlier run with the same seed.
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(), checkpoint_path)
        if stop_flag:
            break

    target_regressor.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Final pass on the best model; history=None with seed/output_folder makes the evaluator
    # write its outputs to model_save_folder.
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=device,
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
def fine_tune_encoder(encoder, train_dataloader, val_dataloader, seed, task_save_folder, test_dataloader=None,
                      metric_name='cpearsonr', normalize_flag=False, **kwargs):
    """Fine-tune a MoMLP regression head on ``encoder``, progressively unfreezing encoder Linear layers.

    Training starts with only the decoder in the optimiser. Each time ``model_save_check``
    signals a stop, the following happens:

    - the best checkpoint is reloaded;
    - the next Linear module of the encoder (from the last one backwards) is added to the
      trainable set;
    - the learning rate is multiplied by ``kwargs['decay_coefficient']``;
    - a new AdamW optimiser is built over all trainable groups so far.

    Training ends when a stop is signalled after every Linear layer is already trainable, or
    after ``kwargs['train_num_epochs']`` epochs.

    :param encoder: encoder module to fine-tune; its Linear sub-modules are found via
        ``str(module).startswith('Linear')``
    :param train_dataloader: training loader
    :param val_dataloader: loader whose ``metric_name`` history drives checkpointing and
        stopping (tolerance 10 epochs)
    :param seed: used in checkpoint file names
    :param task_save_folder: folder receiving ``target_regressor_{seed}.pt`` and
        ``ft_encoder_{seed}.pt``
    :param test_dataloader: optional test loader; skipped entirely when None
    :param metric_name: key of the validation metric passed to ``model_save_check``
    :param normalize_flag: forwarded to ``EncoderDecoder``
    :param kwargs: reads latent_dim, output_dim, regressor_hidden_dims, device, lr,
        train_num_epochs, decay_coefficient and model_save_folder (the folder for the final
        evaluation outputs)
    :return: ``(target_regressor, (train history, train eval history, validation eval
        history, test eval history))``
    """
    device = kwargs['device']
    regressor_path = os.path.join(task_save_folder, f'target_regressor_{seed}.pt')
    encoder_path = os.path.join(task_save_folder, f'ft_encoder_{seed}.pt')

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(device)
    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder,
                                      normalize_flag=normalize_flag).to(device)

    target_regression_train_history = defaultdict(list)
    target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    # Positions of the Linear layers within encoder.modules(). pop() consumes them from the
    # last one backwards, so layers nearest the head are unfrozen first.
    encoder_module_indices = [
        i for i, module in enumerate(encoder.modules()) if str(module).startswith('Linear')
    ]
    reset_count = 1
    lr = kwargs['lr']
    # Store each parameter group as a list. Module.parameters() returns a one-shot iterator,
    # so keeping the raw iterators would leave earlier groups empty when the optimiser is
    # rebuilt after an unfreeze, silently dropping the decoder from training.
    target_regression_params = [list(target_regressor.decoder.parameters())]
    target_regression_optimizer = torch.optim.AdamW(chain(*target_regression_params), lr=lr)

    for epoch in range(kwargs['train_num_epochs']):
        if epoch % 50 == 0:
            print(f'Fine tuning epoch {epoch}')
        for batch in train_dataloader:
            target_regression_train_history = regression_train_step(
                model=target_regressor, batch=batch, device=device,
                optimizer=target_regression_optimizer, history=target_regression_train_history)

        target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=train_dataloader, device=device,
            history=target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=val_dataloader, device=device,
            history=target_regression_eval_val_history)
        if test_dataloader is not None:
            target_regression_eval_test_history = evaluate_target_regression_epoch(
                regressor=target_regressor, dataloader=test_dataloader, device=device,
                history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=10,
                                                reset_count=reset_count)
        # Always checkpoint the first epoch so the reloads below never hit a missing or stale file.
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(), regressor_path)
            torch.save(target_regressor.encoder.state_dict(), encoder_path)
        if stop_flag:
            if not encoder_module_indices:
                # Every Linear layer is already trainable; nothing left to unfreeze.
                break
            ind = encoder_module_indices.pop()
            print(f'Unfreezing {epoch}')
            target_regressor.load_state_dict(torch.load(regressor_path, map_location=device))
            target_regression_params.append(list(list(target_regressor.encoder.modules())[ind].parameters()))
            lr = lr * kwargs['decay_coefficient']
            # Fresh optimiser over every group trained so far (optimiser state is reset).
            target_regression_optimizer = torch.optim.AdamW(chain(*target_regression_params), lr=lr)
            reset_count += 1

    target_regressor.load_state_dict(torch.load(regressor_path, map_location=device))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=val_dataloader,
                                     device=device,
                                     history=None,
                                     seed=seed,
                                     cv_flag=True,
                                     output_folder=kwargs['model_save_folder'])
    if test_dataloader is not None:
        evaluate_target_regression_epoch(regressor=target_regressor,
                                         dataloader=test_dataloader,
                                         device=device,
                                         history=None,
                                         seed=seed,
                                         output_folder=kwargs['model_save_folder'])

    return target_regressor, (target_regression_train_history,
                              target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
def train_dann(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader, metric_name, seed, **kwargs):
    """Train a DANN-style regressor and return it reloaded from its best validation checkpoint.

    A fresh AE encoder is shared by two models:

    - ``target_regressor``: encoder + MoMLP head;
    - ``confounder_classifier``: encoder + single-output MLP domain classifier.

    The regressor parameters and the classifier head are optimised by one AdamW optimiser
    through ``dann_train_step``.

    :param s_dataloaders: source-domain training loader; it is iterated directly as one
        loader, despite the plural name
    :param t_dataloaders: target-domain training loader; one batch is drawn from it for
        every source batch
    :param val_dataloader: loader whose ``metric_name`` history drives checkpointing and
        early stopping (tolerance 50 epochs)
    :param test_dataloader: loader evaluated every epoch, and once more on the reloaded best model
    :param metric_name: key of the validation metric passed to ``model_save_check``
    :param seed: used in the checkpoint file name
    :param kwargs: reads input_dim, latent_dim, encoder_hidden_dims, dop, device, output_dim,
        regressor_hidden_dims, classifier_hidden_dims, lr, train_num_epochs, alpha and
        model_save_folder
    :return: ``(target_regressor, (train_history, source-train eval history,
        target-train eval history, validation eval history, test eval history))``
    """
    device = kwargs['device']
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders
    checkpoint_path = os.path.join(kwargs['model_save_folder'], f'dann_regressor_{seed}.pt')

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(device)
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(device)
    target_regressor = EncoderDecoder(encoder=encoder, decoder=target_decoder).to(device)

    # Domain classifier head; it ends in a Sigmoid, so it outputs probabilities.
    classifier = MLP(input_dim=kwargs['latent_dim'],
                     output_dim=1,
                     hidden_dims=kwargs['classifier_hidden_dims'],
                     dop=kwargs['dop'],
                     out_fn=torch.nn.Sigmoid).to(device)
    confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder, decoder=classifier).to(device)

    train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    # The classifier already applies a Sigmoid. BCEWithLogitsLoss would apply a second one,
    # confining predictions to (0.5, 0.73), so BCELoss is used on the probabilities instead.
    # NOTE(review): assumes dann_train_step feeds classifier outputs straight into loss_fn - confirm.
    confounded_loss = nn.BCELoss()

    dann_optimizer = torch.optim.AdamW(
        chain(target_regressor.parameters(), confounder_classifier.decoder.parameters()),
        lr=kwargs['lr'])

    # Keep one target iterator alive and restart it only when exhausted. Calling
    # next(iter(loader)) every step would always yield the first batch of an unshuffled
    # loader and would pay the full iterator setup cost on every step.
    t_iter = iter(t_train_dataloader)
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'DANN training epoch {epoch}')
        for s_batch in s_train_dataloader:
            try:
                t_batch = next(t_iter)
            except StopIteration:
                t_iter = iter(t_train_dataloader)
                t_batch = next(t_iter)
            train_history = dann_train_step(classifier=confounder_classifier,
                                            model=target_regressor,
                                            s_batch=s_batch,
                                            t_batch=t_batch,
                                            loss_fn=confounded_loss,
                                            alpha=kwargs['alpha'],
                                            device=device,
                                            optimizer=dann_optimizer,
                                            history=train_history,
                                            scheduler=None)

        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=s_train_dataloader, device=device,
            history=s_target_regression_eval_train_history)
        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=t_train_dataloader, device=device,
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=val_dataloader, device=device,
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=test_dataloader, device=device,
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        # Always checkpoint the first epoch so the reload below never hits a missing or stale file.
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(), checkpoint_path)
        if stop_flag:
            break

    target_regressor.load_state_dict(torch.load(checkpoint_path, map_location=device))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=device,
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
def fine_tune_encoder(train_dataloader, val_dataloader, seed, test_dataloader=None, metric_name='cpearsonr',
                      normalize_flag=False, **kwargs):
    """Train a freshly initialised AE encoder + MoMLP regressor end to end.

    Encoder and decoder parameters are optimised jointly with AdamW. After every epoch the
    model is evaluated on the training, validation and (if given) test loaders.
    ``model_save_check`` on the validation history (tolerance 50 epochs) decides when to
    checkpoint and when to stop early; the first epoch is always checkpointed. The best
    checkpoint is reloaded before the final evaluation.

    :param train_dataloader: training loader
    :param val_dataloader: validation loader; also evaluated at the end with ``cv_flag=True``
    :param seed: used in checkpoint file names
    :param test_dataloader: optional test loader; skipped when None
    :param metric_name: key of the validation metric passed to ``model_save_check``
    :param normalize_flag: forwarded to ``EncoderDecoder``
    :param kwargs: reads input_dim, latent_dim, encoder_hidden_dims, dop, device, output_dim,
        regressor_hidden_dims, lr, train_num_epochs and model_save_folder
    :return: ``(target_regressor, (train history, train eval history, validation eval
        history, test eval history))``
    """
    dev = kwargs['device']
    out_dir = kwargs['model_save_folder']
    have_test = test_dataloader is not None

    ae = AE(input_dim=kwargs['input_dim'],
            latent_dim=kwargs['latent_dim'],
            hidden_dims=kwargs['encoder_hidden_dims'],
            dop=kwargs['dop']).to(dev)
    head = MoMLP(input_dim=kwargs['latent_dim'],
                 output_dim=kwargs['output_dim'],
                 hidden_dims=kwargs['regressor_hidden_dims'],
                 out_fn=torch.nn.Sigmoid).to(dev)
    model = EncoderDecoder(encoder=ae.encoder, decoder=head, normalize_flag=normalize_flag).to(dev)

    train_hist, fit_hist, val_hist, test_hist = (defaultdict(list) for _ in range(4))

    optimizer = torch.optim.AdamW(model.parameters(), lr=kwargs['lr'])

    for epoch in range(kwargs['train_num_epochs']):
        if not epoch % 10:
            print(f'MLP fine-tuning epoch {epoch}')

        for batch in train_dataloader:
            train_hist = regression_train_step(model=model, batch=batch, device=dev,
                                               optimizer=optimizer, history=train_hist)

        fit_hist = evaluate_target_regression_epoch(regressor=model, dataloader=train_dataloader,
                                                    device=dev, history=fit_hist)
        val_hist = evaluate_target_regression_epoch(regressor=model, dataloader=val_dataloader,
                                                    device=dev, history=val_hist)
        if have_test:
            test_hist = evaluate_target_regression_epoch(regressor=model, dataloader=test_dataloader,
                                                         device=dev, history=test_hist)

        save_flag, stop_flag = model_save_check(history=val_hist, metric_name=metric_name, tolerance_count=50)
        # Epoch 0 is saved unconditionally so a checkpoint always exists for the reload below.
        if epoch == 0 or save_flag:
            torch.save(model.state_dict(), os.path.join(out_dir, f'target_regressor_{seed}.pt'))
            torch.save(model.encoder.state_dict(), os.path.join(out_dir, f'ft_encoder_{seed}.pt'))
        if stop_flag:
            break

    model.load_state_dict(torch.load(os.path.join(out_dir, f'target_regressor_{seed}.pt')))

    evaluate_target_regression_epoch(regressor=model, dataloader=val_dataloader, device=dev,
                                     history=None, seed=seed, cv_flag=True, output_folder=out_dir)
    if have_test:
        evaluate_target_regression_epoch(regressor=model, dataloader=test_dataloader, device=dev,
                                         history=None, seed=seed, output_folder=out_dir)

    return model, (train_hist, fit_hist, val_hist, test_hist)
def train_adda(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader, metric_name, seed, **kwargs):
    """Train an ADDA-style regressor adversarially and return its best validation checkpoint.

    A fresh AE encoder + MoMLP head form ``target_regressor``; a separate MLP critic
    (single output, no output activation) is trained with RMSprop via ``critic_train_step``
    (``gp=10.0``) on every source batch. Every 5th step the regressor is updated with AdamW
    via ``gan_gen_train_step`` (``alpha=1.0``).

    :param s_dataloaders: source-domain training loader; it is iterated directly as one
        loader, despite the plural name
    :param t_dataloaders: target-domain training loader; one batch is drawn from it for
        every source batch
    :param val_dataloader: loader whose ``metric_name`` history drives checkpointing and
        early stopping (tolerance 50 epochs)
    :param test_dataloader: loader evaluated every epoch, and once more on the reloaded best model
    :param metric_name: key of the validation metric passed to ``model_save_check``
    :param seed: used in the checkpoint file name
    :param kwargs: reads input_dim, latent_dim, encoder_hidden_dims, dop, device, output_dim,
        regressor_hidden_dims, classifier_hidden_dims, lr, train_num_epochs and
        model_save_folder
    :return: ``(target_regressor, (critic history, generator history, source-train eval
        history, target-train eval history, validation eval history, test eval history))``
    """
    device = kwargs['device']
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders
    checkpoint_path = os.path.join(kwargs['model_save_folder'], f'adda_regressor_{seed}.pt')

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(device)
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(device)
    target_regressor = EncoderDecoder(encoder=encoder, decoder=target_decoder).to(device)

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(device)

    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    model_optimizer = torch.optim.AdamW(target_regressor.parameters(), lr=kwargs['lr'])
    classifier_optimizer = torch.optim.RMSprop(confounding_classifier.parameters(), lr=kwargs['lr'])

    # Keep one target iterator alive and restart it only when exhausted. Calling
    # next(iter(loader)) every step would always yield the first batch of an unshuffled
    # loader and would pay the full iterator setup cost on every step.
    t_iter = iter(t_train_dataloader)
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'ADDA training epoch {epoch}')
        for step, s_batch in enumerate(s_train_dataloader):
            try:
                t_batch = next(t_iter)
            except StopIteration:
                t_iter = iter(t_train_dataloader)
                t_batch = next(t_iter)
            critic_train_history = critic_train_step(critic=confounding_classifier,
                                                     model=target_regressor,
                                                     s_batch=s_batch,
                                                     t_batch=t_batch,
                                                     device=device,
                                                     optimizer=classifier_optimizer,
                                                     history=critic_train_history,
                                                     gp=10.0)
            # The critic is updated every step, the regressor ("generator") every 5th step.
            if (step + 1) % 5 == 0:
                gen_train_history = gan_gen_train_step(critic=confounding_classifier,
                                                       model=target_regressor,
                                                       s_batch=s_batch,
                                                       t_batch=t_batch,
                                                       device=device,
                                                       optimizer=model_optimizer,
                                                       alpha=1.0,
                                                       history=gen_train_history)

        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=s_train_dataloader, device=device,
            history=s_target_regression_eval_train_history)
        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=t_train_dataloader, device=device,
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=val_dataloader, device=device,
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor, dataloader=test_dataloader, device=device,
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        # Always checkpoint the first epoch so the reload below never hits a missing or stale file.
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(), checkpoint_path)
        if stop_flag:
            break

    target_regressor.load_state_dict(torch.load(checkpoint_path, map_location=device))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=device,
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (critic_train_history,
                              gen_train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)