def train_cleitc(dataloader, seed, **kwargs):
    """
    :param dataloader: training dataloader
    :param seed: seed used to locate the pre-trained reference encoder checkpoint
    :param kwargs: training configuration
    :return: the trained EncoderDecoder and the (train, test) history dicts
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)
    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save/ae5000', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [
            autoencoder.parameters(),
            transmitter.parameters()
        ]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 1 == 0:
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = cleit_train_step(ae=autoencoder,
                                                         reference_encoder=reference_encoder,
                                                         transmitter=transmitter,
                                                         batch=batch,
                                                         device=kwargs['device'],
                                                         optimizer=cleit_optimizer,
                                                         history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_eval_train_history, ae_eval_test_history)
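# Illustrative only: train_cleitc reads its configuration from **kwargs. The dictionary
# below is a hypothetical example of the keys the function accesses; the values are
# placeholders, not settings taken from this repository.
example_cleitc_kwargs = {
    'input_dim': 5000,                  # hypothetical input feature dimension
    'latent_dim': 128,                  # hypothetical latent size
    'encoder_hidden_dims': [512, 256],  # hypothetical encoder widths
    'dop': 0.1,                         # dropout probability
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'retrain_flag': True,               # False loads cleit_ae.pt / transmitter.pt instead of training
    'lr': 1e-4,
    'train_num_epochs': 500,
    'model_save_folder': './model_save/cleitc',
}
# encoder, (train_hist, test_hist) = train_cleitc(dataloader, seed=2020, **example_cleitc_kwargs)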
def train_ae(dataloader, **kwargs):
    """
    :param dataloader: training dataloader
    :param kwargs: training configuration
    :return: the trained encoder and the (train, test) history dicts
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = ae_train_step(ae=autoencoder,
                                                      batch=batch,
                                                      device=kwargs['device'],
                                                      optimizer=ae_optimizer,
                                                      history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'ae.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    return autoencoder.encoder, (ae_eval_train_history, ae_eval_test_history)
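# The per-batch update is delegated to ae_train_step, defined elsewhere in the repository.
# As a rough sketch only: a plain reconstruction step could look like the function below,
# assuming each batch is a (features, ...) tuple and that the AE forward pass returns a
# reconstruction of its input; the real ae_train_step may add regularization terms or
# track different metrics.
import torch.nn.functional as F

def simple_reconstruction_step(ae, batch, device, optimizer, history):
    ae.train()
    x = batch[0].to(device)
    optimizer.zero_grad()
    recon_loss = F.mse_loss(ae(x), x)   # assumes ae(x) returns the reconstruction
    recon_loss.backward()
    optimizer.step()
    history['recon_loss'].append(recon_loss.item())
    return history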
def train_cleita(dataloader, seed, **kwargs):
    """
    Adversarial CLEIT training: the transmitted latent codes are aligned with the
    reference encoder's codes through a WGAN-style critic (confounding classifier).

    :param dataloader: training dataloader
    :param seed: seed used to locate the pre-trained reference encoder checkpoint
    :param kwargs: training configuration
    :return: the trained EncoderDecoder and the training history dicts
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)
    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    ae_train_history = defaultdict(list)
    ae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [autoencoder.parameters(), transmitter.parameters()]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])
        classifier_optimizer = torch.optim.RMSprop(confounding_classifier.parameters(),
                                                   lr=kwargs['lr'])
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'confounder wgan training epoch {epoch}')
            for step, batch in enumerate(dataloader):
                critic_train_history = critic_train_step(critic=confounding_classifier,
                                                         ae=autoencoder,
                                                         reference_encoder=reference_encoder,
                                                         transmitter=transmitter,
                                                         batch=batch,
                                                         device=kwargs['device'],
                                                         optimizer=classifier_optimizer,
                                                         history=critic_train_history,
                                                         # clip=0.1,
                                                         gp=10.0)
                # update the generator (autoencoder + transmitter) every 5 critic steps
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_gen_train_step(critic=confounding_classifier,
                                                           ae=autoencoder,
                                                           transmitter=transmitter,
                                                           batch=batch,
                                                           device=kwargs['device'],
                                                           optimizer=cleit_optimizer,
                                                           alpha=1.0,
                                                           history=gen_train_history)
        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_train_history, ae_val_history, critic_train_history, gen_train_history)
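# critic_train_step is called with gp=10.0, the usual WGAN-GP penalty weight. For
# reference, a standard gradient-penalty term (Gulrajani et al., 2017) can be computed
# as below; this is a generic sketch, not the repository's critic_train_step.
def gradient_penalty(critic, real_latent, fake_latent, device):
    # Interpolate between reference-encoder codes and transmitted codes.
    alpha = torch.rand(real_latent.size(0), 1, device=device)
    interpolated = (alpha * real_latent + (1 - alpha) * fake_latent).requires_grad_(True)
    critic_scores = critic(interpolated)
    gradients = torch.autograd.grad(outputs=critic_scores,
                                    inputs=interpolated,
                                    grad_outputs=torch.ones_like(critic_scores),
                                    create_graph=True,
                                    retain_graph=True)[0]
    # Penalize deviation of the per-sample gradient norm from 1.
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()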
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_transform = transforms.Compose([
    transforms.ToTensor(),
])
testset = torchvision.datasets.ImageFolder(val_data_folder, transform=test_transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=20)

model = AE(K=K).to(device)
model = nn.DataParallel(model, device_ids=[0])
# remap weights saved from cuda:1 onto cuda:0
model.load_state_dict(torch.load(saved_model_name, map_location={'cuda:1': 'cuda:0'}))

# make sure the output folders exist before saving per-batch tensors
for sub in ('out', 'lab', 'hash'):
    os.makedirs(os.path.join(save_folder_name, sub), exist_ok=True)

model.eval()
with torch.no_grad():
    with tqdm(total=len(test_loader), desc="Batches") as pbar:
        for i, data in enumerate(test_loader):
            img, labels = data
            img = img.to(device)
            encoded, out, hashed = model(img)
            torch.save(out, save_folder_name + "/out/out_{}.pt".format(i))
            torch.save(labels, save_folder_name + "/lab/lab_{}.pt".format(i))
            torch.save(hashed, save_folder_name + "/hash/hash_{}.pt".format(i))
            pbar.update(1)
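# Downstream evaluation presumably reloads the per-batch tensors saved above. A
# hypothetical sketch (not part of the original script) of stitching them back into
# full matrices:
all_hashes = torch.cat([torch.load(os.path.join(save_folder_name, 'hash', 'hash_{}.pt'.format(i)))
                        for i in range(len(test_loader))], dim=0)
all_labels = torch.cat([torch.load(os.path.join(save_folder_name, 'lab', 'lab_{}.pt'.format(i)))
                        for i in range(len(test_loader))], dim=0)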
send_slack_notif("Started with train stage-{}".format(trn_stage_no + 1))

tensorboard_folder = '../runs/ae_hash/ae_stage{}'.format(trn_stage_no + 1)
writer = SummaryWriter(tensorboard_folder)

trainset = torchvision.datasets.ImageFolder(train_data_folder, transform=train_transform)
valset = torchvision.datasets.ImageFolder(val_data_folder, transform=val_transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           shuffle=True, num_workers=20)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                         shuffle=False, num_workers=20)

model = AE(K=K).to(device)
model = nn.DataParallel(model, device_ids=[0, 1])
if train_stage["use_weight"]:
    model.load_state_dict(torch.load(saved_model_name + '_stage1.pt'), strict=False)

# Adding layer parameters for different (10x faster than pretrained) learning rate
fast_learning_layers = ['hashed_layer.{}'.format(ii) for ii in [0, 2, 4, 6]]
fast_learning_layers = ['module.' + s + sb for s in fast_learning_layers
                        for sb in ['.weight', '.bias']]
params = [p for name, p in model.named_parameters() if name in fast_learning_layers]
base_params = [p for name, p in model.named_parameters() if name not in fast_learning_layers]
assert len(params) == len(fast_learning_layers)

# Initializing losses and adding loss params to optimizer with higher lr
criterion1 = comp_loss()
criterion2 = cauchy_loss(K=K, q_lambda=q_lambda)
params.extend(list(criterion1.parameters()))
params.extend(list(criterion2.parameters()))
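# The two parameter lists are presumably handed to a single optimizer with per-group
# learning rates, so the hash layers and the loss parameters learn 10x faster than the
# pretrained backbone. A sketch of that pattern, assuming SGD and a base learning rate
# variable `lr` defined elsewhere in the script:
optimizer = torch.optim.SGD([
    {'params': base_params, 'lr': lr},        # pretrained backbone at the base rate
    {'params': params, 'lr': 10 * lr},        # hash layers + loss parameters at 10x
], momentum=0.9)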