def main(
        model,
        time_stamp,
        device,
        ally_classes,
        advr_1_classes,
        advr_2_classes,
        encoding_dim,
        hidden_dim,
        leaky,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        lr_ally,
        lr_advr_1,
        lr_advr_2,
        expt,
        pca_ckpt,
        autoencoder_ckpt,
        encoder_ckpt,
):
    """Train ally and adversary classifiers on top of a frozen, pre-trained encoder.

    The processed dataset for ``expt`` is split into train/validation sets
    (stratified on all three label sets) and standardized. It is then passed
    through the encoder loaded from ``encoder_ckpt``. Three
    ``DiscriminatorFCN`` heads are trained one after another with
    ``BCEWithLogitsLoss``: adversary 1, adversary 2 and the ally. Every
    ``max(1, int(n_epochs / 10))`` epochs, validation loss and accuracy are
    logged.

    The loss history dict ``h`` is pickled to
    ``checkpoints/<expt>/<model>_training_history_<time_stamp>.pkl``.

    ``pca_ckpt`` and ``autoencoder_ckpt`` are accepted for interface
    compatibility but are not used here.
    """
    device = torch_device(device=device)

    X, y_ally, y_advr_1, y_advr_2 = load_processed_data(
        expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl')
    log_shapes([X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded')

    X_train, X_valid, \
        y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = train_test_split(
            X,
            y_ally,
            y_advr_1,
            y_advr_2,
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    y_ally.reshape(-1, ally_classes),
                    y_advr_1.reshape(-1, advr_1_classes),
                    y_advr_2.reshape(-1, advr_2_classes),
                ), axis=1)
            )
        )

    log_shapes([
        X_train, X_valid,
        y_ally_train, y_ally_valid,
        y_advr_1_train, y_advr_1_valid,
        y_advr_2_train, y_advr_2_valid,
    ], locals(), 'Data size after train test split')

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    # map_location keeps the encoder on the same device as the data, even if
    # the checkpoint was saved on a different one (e.g. GPU -> CPU host).
    encoder = torch.load(encoder_ckpt, map_location=device)
    encoder.eval()

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()
    # Guard against ZeroDivisionError when n_epochs < 10.
    eval_every = max(1, int(n_epochs / 10))

    stage = 'encoder'
    h = {
        'epoch': {
            'train': [],
            'valid': [],
        },
        stage: {
            'ally_train': [],
            'ally_valid': [],
            'advr_1_train': [],
            'advr_1_valid': [],
            'advr_2_train': [],
            'advr_2_valid': [],
        },
    }

    # Dataset column layout: (X, y_ally, y_advr_1, y_advr_2).
    dataset_train = utils.TensorDataset(
        torch.Tensor(X_normalized_train),
        torch.Tensor(y_ally_train.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)),
    )
    dataset_valid = utils.TensorDataset(
        torch.Tensor(X_normalized_valid),
        torch.Tensor(y_ally_valid.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)),
    )

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1)
    dataloader_valid = torch.utils.data.DataLoader(
        dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1)

    ally = DiscriminatorFCN(
        encoding_dim, hidden_dim, ally_classes, leaky).to(device)
    advr_1 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_1_classes, leaky).to(device)
    advr_2 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_2_classes, leaky).to(device)

    ally.apply(weights_init)
    advr_1.apply(weights_init)
    advr_2.apply(weights_init)

    sep('{}:{}'.format(stage, 'ally'))
    summary(ally, input_size=(1, encoding_dim))
    sep('{}:{}'.format(stage, 'advr 1'))
    summary(advr_1, input_size=(1, encoding_dim))
    sep('{}:{}'.format(stage, 'advr 2'))
    summary(advr_2, input_size=(1, encoding_dim))

    optimizer_ally = optim(ally.parameters(), lr=lr_ally)
    optimizer_advr_1 = optim(advr_1.parameters(), lr=lr_advr_1)
    optimizer_advr_2 = optim(advr_2.parameters(), lr=lr_advr_2)

    def transform(input_arg):
        # The encoder is frozen (it has no optimizer here). Skipping autograd
        # saves memory and avoids accumulating unused gradients on it.
        with torch.no_grad():
            return encoder(input_arg)

    def fit_head(clf, optimizer, target_idx, key, record_epochs=False):
        """Train ``clf`` on encoded features against dataset column ``target_idx``.

        Appends the per-epoch mean train loss to ``h[stage][key + '_train']``.
        Every ``eval_every`` epochs it also appends the mean validation loss
        to ``h[stage][key + '_valid']`` and logs loss and accuracy. If
        ``record_epochs`` is true, the epoch indices are also stored in
        ``h['epoch']``.
        """
        train_key = key + '_train'
        valid_key = key + '_valid'
        for epoch in range(n_epochs):
            clf.train()
            nsamples = 0
            iloss = 0
            for data in dataloader_train:
                X_train_torch = transform(data[0].to(device))
                y_train_torch = data[target_idx].to(device)
                optimizer.zero_grad()
                loss = criterionBCEWithLogits(clf(X_train_torch), y_train_torch)
                loss.backward()
                optimizer.step()
                nsamples += 1
                iloss += loss.item()
            if record_epochs and epoch not in h['epoch']['train']:
                h['epoch']['train'].append(epoch)
            h[stage][train_key].append(iloss / nsamples)

            if epoch % eval_every != 0:
                continue

            clf.eval()
            nsamples = 0
            iloss = 0
            correct = 0
            total = 0
            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = transform(data[0].to(device))
                    y_valid_torch = data[target_idx].to(device)
                    y_valid_hat_torch = clf(X_valid_torch)
                    iloss += criterionBCEWithLogits(
                        y_valid_hat_torch, y_valid_torch).item()
                    nsamples += 1
                    # The head emits logits (BCEWithLogitsLoss), and
                    # logit > 0 is equivalent to sigmoid(logit) > 0.5.
                    predicted = (y_valid_hat_torch > 0).float()
                    # numel() counts every compared element, so accuracy
                    # stays within [0, 1] for multi-column targets too.
                    total += y_valid_torch.numel()
                    correct += (predicted == y_valid_torch).sum().item()
            if record_epochs and epoch not in h['epoch']['valid']:
                h['epoch']['valid'].append(epoch)
            h[stage][valid_key].append(iloss / nsamples)
            logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
                epoch,
                h[stage][train_key][-1],
                h[stage][valid_key][-1],
                correct / total))

    # adversary 1
    sep("adversary 1")
    logging.info('{} \t {} \t {}'.format(
        'Epoch',
        'Advr 1 Train',
        'Advr 1 Valid',
    ))
    fit_head(advr_1, optimizer_advr_1, 2, 'advr_1')

    # adversary 2
    sep("adversary 2")
    logging.info('{} \t {} \t {}'.format(
        'Epoch',
        'Advr 2 Train',
        'Advr 2 Valid',
    ))
    fit_head(advr_2, optimizer_advr_2, 3, 'advr_2')

    # ally (also records the epoch indices in h['epoch'])
    sep("ally")
    logging.info('{} \t {} \t {}'.format(
        'Epoch',
        'Ally Train',
        'Ally Valid',
    ))
    fit_head(ally, optimizer_ally, 1, 'ally', record_epochs=True)

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as f:
        pkl.dump(h, f)
def main(
        model,
        time_stamp,
        device,
        ngpu,
        ally_classes,
        advr_1_classes,
        advr_2_classes,
        encoding_dim,
        hidden_dim,
        leaky,
        activation,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        init_weight,
        lr_encd,
        lr_ally,
        lr_advr_1,
        lr_advr_2,
        alpha,
        g_reps,
        d_reps,
        expt,
        marker
):
    """Adversarially train an encoder against one ally and two adversaries.

    Each epoch has two phases:

    * Encoder phase (``g_reps`` passes over the training set). The encoder
      is updated to minimise ``loss_ally - loss_advr_1 - loss_advr_2`` while
      the three heads are in eval mode and are not stepped.
    * Head phase (``d_reps`` passes). The ally (BCE-with-logits), adversary 1
      (BCE-with-logits) and adversary 2 (cross-entropy over one-hot targets)
      are trained on the frozen encoder output.

    Train losses are recorded once per epoch, from the last repetition of
    each phase. Validation losses are computed every
    ``max(1, int(n_epochs / 10))`` epochs. A loss plot, the pickled loss
    history tuple and the full encoder module are saved under
    ``plots/<expt>/`` and ``checkpoints/<expt>/``.

    ``ngpu`` and ``alpha`` are accepted for interface compatibility but are
    not used by this objective.
    """
    device = torch_device(device=device)

    X, y_ally, y_advr_1, y_advr_2 = load_processed_data(
        expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl')
    log_shapes(
        [X, y_ally, y_advr_1, y_advr_2],
        locals(),
        'Dataset loaded'
    )

    X_train, X_valid, \
        y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = train_test_split(
            X,
            y_ally,
            y_advr_1,
            y_advr_2,
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    y_ally.reshape(-1, ally_classes),
                    y_advr_1.reshape(-1, advr_1_classes),
                    y_advr_2.reshape(-1, advr_2_classes),
                ), axis=1)
            )
        )

    log_shapes(
        [
            X_train, X_valid,
            y_ally_train, y_ally_valid,
            y_advr_1_train, y_advr_1_valid,
            y_advr_2_train, y_advr_2_valid,
        ],
        locals(),
        'Data size after train test split'
    )

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    encoder = GeneratorFCN(
        X_normalized_train.shape[1],
        hidden_dim,
        encoding_dim,
        leaky,
        activation).to(device)
    ally = DiscriminatorFCN(
        encoding_dim, hidden_dim, ally_classes, leaky).to(device)
    advr_1 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_1_classes, leaky).to(device)
    advr_2 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_2_classes, leaky).to(device)

    if init_weight:
        sep()
        logging.info('applying weights_init ...')
        encoder.apply(weights_init)
        ally.apply(weights_init)
        advr_1.apply(weights_init)
        advr_2.apply(weights_init)

    sep('encoder')
    summary(encoder, input_size=(1, X_normalized_train.shape[1]))
    sep('ally')
    summary(ally, input_size=(1, encoding_dim))
    sep('advr_1')
    summary(advr_1, input_size=(1, encoding_dim))
    sep('advr_2')
    summary(advr_2, input_size=(1, encoding_dim))

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()
    criterionCrossEntropy = nn.CrossEntropyLoss()

    # NOTE(review): weight_decay is set equal to each learning rate. This is
    # kept unchanged; confirm it is intentional.
    optimizer_encd = optim(
        encoder.parameters(), lr=lr_encd, weight_decay=lr_encd)
    optimizer_ally = optim(
        ally.parameters(), lr=lr_ally, weight_decay=lr_ally)
    optimizer_advr_1 = optim(
        advr_1.parameters(), lr=lr_advr_1, weight_decay=lr_advr_1)
    optimizer_advr_2 = optim(
        advr_2.parameters(), lr=lr_advr_2, weight_decay=lr_advr_2)

    # Target widths follow the configured class counts, so they match the
    # head output sizes. They were previously hard-coded to 1 / 1 / 3.
    # Dataset column layout: (X, y_ally, y_advr_1, y_advr_2).
    dataset_train = utils.TensorDataset(
        torch.Tensor(X_normalized_train),
        torch.Tensor(y_ally_train.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)),
    )
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )

    dataset_valid = utils.TensorDataset(
        torch.Tensor(X_normalized_valid),
        torch.Tensor(y_ally_valid.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)),
    )
    dataloader_valid = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )

    epochs_train = []
    epochs_valid = []
    encd_loss_train = []
    encd_loss_valid = []
    ally_loss_train = []
    ally_loss_valid = []
    advr_1_loss_train = []
    advr_1_loss_valid = []
    advr_2_loss_train = []
    advr_2_loss_valid = []

    logging.info('{} \t {} \t {} \t {} \t {} \t {} \t {} \t {} \t {}'.format(
        'Epoch',
        'Encd Train',
        'Encd Valid',
        'Ally Train',
        'Ally Valid',
        'Advr 1 Train',
        'Advr 1 Valid',
        'Advr 2 Train',
        'Advr 2 Valid',
    ))

    # Guard against ZeroDivisionError when n_epochs < 10.
    eval_every = max(1, int(n_epochs / 10))

    for epoch in range(n_epochs):
        # ---- encoder phase: heads frozen in eval mode, only encoder steps ----
        encoder.train()
        ally.eval()
        advr_1.eval()
        advr_2.eval()
        for __ in range(g_reps):
            nsamples = 0
            iloss = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_ally_train_torch = data[1].to(device)
                y_advr_1_train_torch = data[2].to(device)
                y_advr_2_train_torch = data[3].to(device)

                optimizer_encd.zero_grad()
                X_train_encoded = encoder(X_train_torch)
                loss_ally = criterionBCEWithLogits(
                    ally(X_train_encoded), y_ally_train_torch)
                loss_advr_1 = criterionBCEWithLogits(
                    advr_1(X_train_encoded), y_advr_1_train_torch)
                loss_advr_2 = criterionCrossEntropy(
                    advr_2(X_train_encoded),
                    torch.argmax(y_advr_2_train_torch, 1))
                loss_encd = loss_ally - loss_advr_1 - loss_advr_2
                loss_encd.backward()
                optimizer_encd.step()

                nsamples += 1
                iloss += loss_encd.item()
        # Record once per epoch, so every *_loss_train list stays aligned
        # with epochs_train whatever g_reps / d_reps are.
        epochs_train.append(epoch)
        encd_loss_train.append(iloss / nsamples)

        # ---- head phase: encoder frozen, heads trained on its output ----
        encoder.eval()
        ally.train()
        advr_1.train()
        advr_2.train()
        for __ in range(d_reps):
            nsamples = 0
            iloss_ally = 0
            iloss_advr_1 = 0
            iloss_advr_2 = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_ally_train_torch = data[1].to(device)
                y_advr_1_train_torch = data[2].to(device)
                y_advr_2_train_torch = data[3].to(device)

                # The encoder is in eval mode and is not stepped here, so one
                # gradient-free forward pass serves all three heads.
                with torch.no_grad():
                    X_train_encoded = encoder(X_train_torch)

                optimizer_ally.zero_grad()
                loss_ally = criterionBCEWithLogits(
                    ally(X_train_encoded), y_ally_train_torch)
                loss_ally.backward()
                optimizer_ally.step()

                optimizer_advr_1.zero_grad()
                loss_advr_1 = criterionBCEWithLogits(
                    advr_1(X_train_encoded), y_advr_1_train_torch)
                loss_advr_1.backward()
                optimizer_advr_1.step()

                optimizer_advr_2.zero_grad()
                loss_advr_2 = criterionCrossEntropy(
                    advr_2(X_train_encoded),
                    torch.argmax(y_advr_2_train_torch, 1))
                loss_advr_2.backward()
                optimizer_advr_2.step()

                nsamples += 1
                iloss_ally += loss_ally.item()
                iloss_advr_1 += loss_advr_1.item()
                iloss_advr_2 += loss_advr_2.item()
        ally_loss_train.append(iloss_ally / nsamples)
        advr_1_loss_train.append(iloss_advr_1 / nsamples)
        advr_2_loss_train.append(iloss_advr_2 / nsamples)

        if epoch % eval_every != 0:
            continue

        # ---- validation ----
        encoder.eval()
        ally.eval()
        advr_1.eval()
        advr_2.eval()

        nsamples = 0
        iloss = 0
        iloss_ally = 0
        iloss_advr_1 = 0
        iloss_advr_2 = 0

        with torch.no_grad():
            for data in dataloader_valid:
                X_valid_torch = data[0].to(device)
                y_ally_valid_torch = data[1].to(device)
                y_advr_1_valid_torch = data[2].to(device)
                y_advr_2_valid_torch = data[3].to(device)

                X_valid_encoded = encoder(X_valid_torch)
                valid_loss_ally = criterionBCEWithLogits(
                    ally(X_valid_encoded), y_ally_valid_torch)
                valid_loss_advr_1 = criterionBCEWithLogits(
                    advr_1(X_valid_encoded), y_advr_1_valid_torch)
                valid_loss_advr_2 = criterionCrossEntropy(
                    advr_2(X_valid_encoded),
                    torch.argmax(y_advr_2_valid_torch, 1))
                valid_loss_encd = valid_loss_ally - valid_loss_advr_1 - \
                    valid_loss_advr_2

                nsamples += 1
                iloss += valid_loss_encd.item()
                iloss_ally += valid_loss_ally.item()
                iloss_advr_1 += valid_loss_advr_1.item()
                iloss_advr_2 += valid_loss_advr_2.item()

        epochs_valid.append(epoch)
        encd_loss_valid.append(iloss / nsamples)
        ally_loss_valid.append(iloss_ally / nsamples)
        advr_1_loss_valid.append(iloss_advr_1 / nsamples)
        advr_2_loss_valid.append(iloss_advr_2 / nsamples)
        logging.info(
            '{} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f}'.
            format(
                epoch,
                encd_loss_train[-1],
                encd_loss_valid[-1],
                ally_loss_train[-1],
                ally_loss_valid[-1],
                advr_1_loss_train[-1],
                advr_1_loss_valid[-1],
                advr_2_loss_train[-1],
                advr_2_loss_valid[-1],
            ))

    config_summary = '{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_lrencd_{}_lrally_{}_tr_{:.4f}_val_{:.4f}'\
        .format(
            marker,
            device,
            encoding_dim,
            hidden_dim,
            batch_size,
            n_epochs,
            lr_encd,
            lr_ally,
            encd_loss_train[-1],
            advr_1_loss_valid[-1],
        )

    plt.figure()
    plt.plot(epochs_train, encd_loss_train, 'r')
    plt.plot(epochs_valid, encd_loss_valid, 'r--')
    plt.plot(epochs_train, ally_loss_train, 'b')
    plt.plot(epochs_valid, ally_loss_valid, 'b--')
    plt.plot(epochs_train, advr_1_loss_train, 'g')
    plt.plot(epochs_valid, advr_1_loss_valid, 'g--')
    plt.plot(epochs_train, advr_2_loss_train, 'y')
    plt.plot(epochs_valid, advr_2_loss_valid, 'y--')
    plt.legend([
        'encoder train', 'encoder valid',
        'ally train', 'ally valid',
        'advr 1 train', 'advr 1 valid',
        'advr 2 train', 'advr 2 valid',
    ])
    plt.title("{} on {} training".format(model, expt))

    plot_location = 'plots/{}/{}_training_{}_{}.png'.format(
        expt, model, time_stamp, config_summary)
    sep()
    logging.info('Saving: {}'.format(plot_location))
    plt.savefig(plot_location)
    plt.close()

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}_{}.pkl'.format(
            expt, model, time_stamp, config_summary)
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as f:
        pkl.dump((
            epochs_train,
            epochs_valid,
            encd_loss_train,
            encd_loss_valid,
            ally_loss_train,
            ally_loss_valid,
            advr_1_loss_train,
            advr_1_loss_valid,
            advr_2_loss_train,
            advr_2_loss_valid,
        ), f)

    model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format(
            expt, model, time_stamp, config_summary)
    logging.info('Saving: {}'.format(model_ckpt))
    torch.save(encoder, model_ckpt)
def main(
        model,
        time_stamp,
        device,
        encoding_dim,
        hidden_dim,
        leaky,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        lr,
        expt,
        pca_ckpt,
        autoencoder_ckpt,
        encoder_ckpt,
):
    """Train an ``admission_type`` classifier on the output of each listed encoder.

    For every entry in ``ckpts`` (index -> saved encoder module), the encoder
    is loaded and frozen. A fresh single-output ``DiscriminatorFCN`` is then
    trained with ``BCEWithLogitsLoss`` on the encoded, standardized features.
    Validation loss and accuracy are logged every
    ``max(1, int(n_epochs / 10))`` epochs. The per-checkpoint loss history
    ``h`` is pickled to
    ``checkpoints/<expt>/<model>_training_history_<time_stamp>.pkl``.

    ``pca_ckpt``, ``autoencoder_ckpt`` and ``encoder_ckpt`` are accepted for
    interface compatibility but are not used here.
    """
    device = torch_device(device=device)

    X, targets = load_processed_data(
        expt, 'processed_data_X_targets.pkl')
    log_shapes(
        [X] + [targets[i] for i in targets],
        locals(),
        'Dataset loaded'
    )

    targets = {i: elem.reshape(-1, 1) for i, elem in targets.items()}

    X_train, X_valid, \
        y_adt_train, y_adt_valid = train_test_split(
            X,
            targets['admission_type'],
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    targets['admission_type'],
                ), axis=1)
            )
        )

    log_shapes(
        [
            X_train, X_valid,
            y_adt_train, y_adt_valid,
        ],
        locals(),
        'Data size after train test split'
    )

    y_train = y_adt_train
    y_valid = y_adt_valid

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    # Uncomment entries to evaluate them. Each model path is preceded by the
    # training-history file it came from.
    ckpts = {
        # 123A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_17_41_09.pkl
        # 0: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_16_13_39_A_n_1_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_0_encd_0.0471_advr_0.5991.pkl',
        # 28B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_00_23_29.pkl
        # 1: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_22_51_11_B_n_2_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_1_encd_0.0475_advr_0.5992.pkl',
        # 123A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_17_41_09.pkl
        # 2: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_16_14_37_A_n_3_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_2_encd_0.0464_advr_0.5991.pkl',
        # 224A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_20_09_39.pkl
        # 3: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_18_08_09_A_n_4_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_3_encd_0.0469_advr_0.5991.pkl',
        # 24A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_00_21_50.pkl
        # 4: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_23_12_05_A_n_5_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_4_encd_0.0468_advr_0.5994.pkl',
        # 67A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_05_30_09.pkl
        # 5: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_00_15_28_A_n_6_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_5_encd_0.0462_advr_0.5991.pkl',
        # 67A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_05_30_09.pkl
        # 6: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_00_42_31_A_n_7_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_6_encd_0.0453_advr_0.5992.pkl',
        # 28B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_00_23_29.pkl
        # 7: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_22_54_21_B_n_8_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_7_encd_0.0477_advr_0.5992.pkl',
        # 9B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_18_09_03.pkl
        # 8: 'checkpoints/mimic/n_eigan_torch_model_02_05_2020_00_48_34_B_n_9_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_8_encd_0.0473_advr_0.6000.pkl',
        # nA: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_20_13_29.pkl
        # 9: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_18_51_01_A_n_10_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_9_encd_0.0420_advr_0.5992.pkl',
    }

    # The loss, data and loaders do not depend on the checkpoint, so they are
    # built once rather than once per encoder.
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()
    # Guard against ZeroDivisionError when n_epochs < 10.
    eval_every = max(1, int(n_epochs / 10))

    dataset_train = utils.TensorDataset(
        torch.Tensor(X_normalized_train),
        torch.Tensor(y_train),
    )
    dataset_valid = utils.TensorDataset(
        torch.Tensor(X_normalized_valid),
        torch.Tensor(y_valid),
    )
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )
    dataloader_valid = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )

    h = {}
    if not ckpts:
        logging.warning(
            'No encoder checkpoints enabled in ckpts; nothing to train.')

    for idx, ckpt in ckpts.items():
        encoder = torch.load(ckpt, map_location=device)
        encoder.eval()

        h[idx] = {
            'epoch_train': [],
            'epoch_valid': [],
            'advr_train': [],
            'advr_valid': [],
        }

        clf = DiscriminatorFCN(
            encoding_dim, hidden_dim, 1, leaky).to(device)
        clf.apply(weights_init)

        sep('{} {}'.format(idx+1, 'ally'))
        summary(clf, input_size=(1, encoding_dim))

        optimizer = torch.optim.Adam(clf.parameters(), lr=lr)

        sep("adversary with {} ally encoder".format(idx+1))
        logging.info('{} \t {} \t {}'.format(
            'Epoch',
            'Advr Train',
            'Advr Valid',
        ))

        for epoch in range(n_epochs):
            clf.train()
            nsamples = 0
            iloss_advr = 0
            for data in dataloader_train:
                # The encoder is frozen, so no autograd graph is needed
                # through it.
                with torch.no_grad():
                    X_train_torch = encoder(data[0].to(device))
                y_advr_train_torch = data[1].to(device)

                optimizer.zero_grad()
                loss_advr = criterionBCEWithLogits(
                    clf(X_train_torch), y_advr_train_torch)
                loss_advr.backward()
                optimizer.step()

                nsamples += 1
                iloss_advr += loss_advr.item()

            h[idx]['advr_train'].append(iloss_advr/nsamples)
            h[idx]['epoch_train'].append(epoch)

            if epoch % eval_every != 0:
                continue

            clf.eval()
            nsamples = 0
            iloss_advr = 0
            correct = 0
            total = 0
            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = encoder(data[0].to(device))
                    y_advr_valid_torch = data[1].to(device)
                    y_advr_valid_hat_torch = clf(X_valid_torch)

                    valid_loss_advr = criterionBCEWithLogits(
                        y_advr_valid_hat_torch, y_advr_valid_torch)
                    # The head emits logits, and logit > 0 is equivalent to
                    # sigmoid(logit) > 0.5.
                    predicted = (y_advr_valid_hat_torch > 0).float()

                    nsamples += 1
                    iloss_advr += valid_loss_advr.item()
                    total += y_advr_valid_torch.size(0)
                    correct += (predicted == y_advr_valid_torch).sum().item()

            h[idx]['advr_valid'].append(iloss_advr/nsamples)
            h[idx]['epoch_valid'].append(epoch)

            logging.info(
                '{} \t {:.8f} \t {:.8f} \t {:.8f}'.
                format(
                    epoch,
                    h[idx]['advr_train'][-1],
                    h[idx]['advr_valid'][-1],
                    correct/total
                ))

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as f:
        pkl.dump(h, f)
def main(
        model,
        time_stamp,
        device,
        ally_classes,
        advr_1_classes,
        advr_2_classes,
        encoding_dim,
        hidden_dim,
        leaky,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        lr,
        expt,
        pca_ckpt,
        autoencoder_ckpt,
        encoder_ckpt,
):
    """Train one binary classifier per target directly on the raw features.

    For each ``(name, target)`` in the processed targets dict, the data is
    split (stratified on that target) and standardized. A single-output
    ``DiscriminatorFCN`` is then trained with ``BCEWithLogitsLoss``.
    Validation loss and accuracy are logged every
    ``max(1, int(n_epochs / 10))`` epochs. Histories keyed by target name are
    pickled to ``checkpoints/<expt>/<model>_training_history_<time_stamp>.pkl``.

    ``ally_classes``, ``advr_1_classes``, ``advr_2_classes``,
    ``encoding_dim``, ``pca_ckpt``, ``autoencoder_ckpt`` and ``encoder_ckpt``
    are accepted for interface compatibility but are not used.
    """
    device = torch_device(device=device)

    X, targets = load_processed_data(expt, 'processed_data_X_targets.pkl')
    log_shapes([X] + [targets[i] for i in targets], locals(), 'Dataset loaded')

    criterionBCEWithLogits = nn.BCEWithLogitsLoss()
    # Guard against ZeroDivisionError when n_epochs < 10.
    eval_every = max(1, int(n_epochs / 10))

    h = {}

    for name, target in targets.items():
        sep(name)

        target = target.reshape(-1, 1)

        X_train, X_valid, \
            y_train, y_valid = train_test_split(
                X,
                target,
                test_size=test_size,
                stratify=target
            )

        log_shapes([
            X_train, X_valid,
            y_train, y_valid,
        ], locals(), 'Data size after train test split')

        scaler = StandardScaler()
        X_normalized_train = scaler.fit_transform(X_train)
        X_normalized_valid = scaler.transform(X_valid)

        log_shapes([X_normalized_train, X_normalized_valid], locals())

        h[name] = {
            'epoch_train': [],
            'epoch_valid': [],
            'y_train': [],
            'y_valid': [],
        }

        # y_train / y_valid are already (n, 1) because target was reshaped.
        dataset_train = utils.TensorDataset(
            torch.Tensor(X_normalized_train),
            torch.Tensor(y_train),
        )
        dataset_valid = utils.TensorDataset(
            torch.Tensor(X_normalized_valid),
            torch.Tensor(y_valid),
        )

        dataloader_train = torch.utils.data.DataLoader(
            dataset_train, batch_size=batch_size, shuffle=shuffle,
            num_workers=1)
        dataloader_valid = torch.utils.data.DataLoader(
            dataset_valid, batch_size=batch_size, shuffle=shuffle,
            num_workers=1)

        # The classifier consumes the raw standardized features, so its input
        # width is the feature count. Using encoding_dim only worked when the
        # two happened to be equal.
        input_dim = X_normalized_train.shape[1]
        clf = DiscriminatorFCN(input_dim, hidden_dim, 1, leaky).to(device)
        clf.apply(weights_init)

        sep('{}:{}'.format(name, 'summary'))
        summary(clf, input_size=(1, input_dim))

        optimizer = torch.optim.Adam(clf.parameters(), lr=lr)

        sep("TRAINING")
        logging.info('{} \t {} \t {}'.format(
            'Epoch',
            'Train',
            'Valid',
        ))

        for epoch in range(n_epochs):
            clf.train()
            nsamples = 0
            iloss_train = 0

            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_train_torch = data[1].to(device)

                optimizer.zero_grad()
                loss_train = criterionBCEWithLogits(
                    clf(X_train_torch), y_train_torch)
                loss_train.backward()
                optimizer.step()

                nsamples += 1
                iloss_train += loss_train.item()

            h[name]['y_train'].append(iloss_train / nsamples)
            h[name]['epoch_train'].append(epoch)

            if epoch % eval_every != 0:
                continue

            clf.eval()

            nsamples = 0
            iloss_valid = 0
            correct = 0
            total = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_valid_torch = data[1].to(device)
                    y_valid_hat_torch = clf(X_valid_torch)

                    valid_loss = criterionBCEWithLogits(
                        y_valid_hat_torch, y_valid_torch)
                    # The head emits logits, and logit > 0 is equivalent to
                    # sigmoid(logit) > 0.5.
                    predicted = (y_valid_hat_torch > 0).float()

                    nsamples += 1
                    iloss_valid += valid_loss.item()
                    total += y_valid_torch.size(0)
                    correct += (predicted == y_valid_torch).sum().item()

            h[name]['y_valid'].append(iloss_valid / nsamples)
            h[name]['epoch_valid'].append(epoch)

            logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
                epoch,
                h[name]['y_train'][-1],
                h[name]['y_valid'][-1],
                correct / total))

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as f:
        pkl.dump(h, f)
def main(
        model,
        time_stamp,
        device,
        ngpu,
        encoding_dim,
        hidden_dim,
        leaky,
        activation,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        init_weight,
        lr_encd,
        lr_ally,
        lr_advr,
        alpha,
        expt,
        num_allies,
        marker
):
    """Adversarially train an encoder with ``num_allies`` allies and one adversary.

    The adversary predicts ``admission_type``. Ally ``j`` predicts the
    ``j``-th entry of a fixed list of ten other targets: hospital expire
    flag, expire flag, gender, language, marital status, religion, insurance,
    discharge location, admission location and ethnicity. ``num_allies``
    presumably must not exceed 10 (TODO confirm).

    Each epoch has two phases. First, with all heads in eval mode, the
    encoder is updated to minimise

        alpha / num_allies * sum(ally losses) - (1 - alpha) * adversary loss

    Then every head is trained on the frozen encoder output. Validation runs
    every ``max(1, int(n_epochs / 10))`` epochs. A loss plot, the pickled
    loss-history tuple and the full encoder module are saved under
    ``plots/<expt>/`` and ``checkpoints/<expt>/``.

    ``ngpu`` is accepted for interface compatibility but is not used.
    """
    device = torch_device(device=device)

    X, targets = load_processed_data(
        expt, 'processed_data_X_targets.pkl')
    log_shapes(
        [X] + [targets[i] for i in targets],
        locals(),
        'Dataset loaded'
    )

    targets = {i: elem.reshape(-1, 1) for i, elem in targets.items()}

    X_train, X_valid, \
        y_hef_train, y_hef_valid, \
        y_exf_train, y_exf_valid, \
        y_gdr_train, y_gdr_valid, \
        y_lan_train, y_lan_valid, \
        y_mar_train, y_mar_valid, \
        y_rel_train, y_rel_valid, \
        y_ins_train, y_ins_valid, \
        y_dis_train, y_dis_valid, \
        y_adl_train, y_adl_valid, \
        y_adt_train, y_adt_valid, \
        y_etn_train, y_etn_valid = train_test_split(
            X,
            targets['hospital_expire_flag'],
            targets['expire_flag'],
            targets['gender'],
            targets['language'],
            targets['marital_status'],
            targets['religion'],
            targets['insurance'],
            targets['discharge_location'],
            targets['admission_location'],
            targets['admission_type'],
            targets['ethnicity'],
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    targets['admission_type'],
                ), axis=1)
            )
        )

    log_shapes(
        [
            X_train, X_valid,
            y_hef_train, y_hef_valid,
            y_exf_train, y_exf_valid,
            y_gdr_train, y_gdr_valid,
            y_lan_train, y_lan_valid,
            y_mar_train, y_mar_valid,
            y_rel_train, y_rel_valid,
            y_ins_train, y_ins_valid,
            y_dis_train, y_dis_valid,
            y_adl_train, y_adl_valid,
            y_adt_train, y_adt_valid,
            y_etn_train, y_etn_valid
        ],
        locals(),
        'Data size after train test split'
    )

    y_ally_trains = [
        y_hef_train,
        y_exf_train,
        y_gdr_train,
        y_lan_train,
        y_mar_train,
        y_rel_train,
        y_ins_train,
        y_dis_train,
        y_adl_train,
        y_etn_train,
    ]
    y_ally_valids = [
        y_hef_valid,
        y_exf_valid,
        y_gdr_valid,
        y_lan_valid,
        y_mar_valid,
        y_rel_valid,
        y_ins_valid,
        y_dis_valid,
        y_adl_valid,
        y_etn_valid,
    ]
    y_advr_train = y_adt_train
    y_advr_valid = y_adt_valid

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    # Guard against ZeroDivisionError when n_epochs < 10.
    eval_every = max(1, int(n_epochs / 10))

    for i in [num_allies-1]:
        sep('NUMBER OF ALLIES: {}'.format(i+1))
        ally_ids = range(i + 1)

        encoder = GeneratorFCN(
            X_normalized_train.shape[1],
            hidden_dim,
            encoding_dim,
            leaky,
            activation).to(device)
        ally = {
            j: DiscriminatorFCN(encoding_dim, hidden_dim, 1, leaky).to(device)
            for j in ally_ids
        }
        advr = DiscriminatorFCN(
            encoding_dim, hidden_dim, 1, leaky).to(device)

        if init_weight:
            sep()
            logging.info('applying weights_init ...')
            encoder.apply(weights_init)
            for j in ally_ids:
                ally[j].apply(weights_init)
            advr.apply(weights_init)

        sep('encoder')
        summary(encoder, input_size=(1, X_normalized_train.shape[1]))
        for j in ally_ids:
            sep('ally:{}'.format(j))
            summary(ally[j], input_size=(1, encoding_dim))
        sep('advr')
        summary(advr, input_size=(1, encoding_dim))

        optim = torch.optim.Adam
        criterionBCEWithLogits = nn.BCEWithLogitsLoss()

        optimizer_encd = optim(encoder.parameters(), lr=lr_encd)
        optimizer_ally = {
            j: optim(ally[j].parameters(), lr=lr_ally) for j in ally_ids
        }
        optimizer_advr = optim(advr.parameters(), lr=lr_advr)

        # Dataset column layout: (X, y_advr, y_ally_0, ..., y_ally_9), so
        # ally j's target is column j + 2.
        dataset_train = utils.TensorDataset(
            torch.Tensor(X_normalized_train),
            torch.Tensor(y_advr_train),
            *[torch.Tensor(y_ally_train) for y_ally_train in y_ally_trains],
        )
        dataloader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=1
        )

        dataset_valid = utils.TensorDataset(
            torch.Tensor(X_normalized_valid),
            torch.Tensor(y_advr_valid),
            *[torch.Tensor(y_ally_valid) for y_ally_valid in y_ally_valids],
        )
        dataloader_valid = torch.utils.data.DataLoader(
            dataset_valid,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=1
        )

        epochs_train = []
        epochs_valid = []
        encd_loss_train = []
        encd_loss_valid = []
        ally_loss_train = {j: [] for j in ally_ids}
        ally_loss_valid = {j: [] for j in ally_ids}
        advr_loss_train = []
        advr_loss_valid = []

        log_list = ['epoch', 'encd_train', 'encd_valid', 'advr_train', 'advr_valid'] + \
            ['ally_{}_train \t ally_{}_valid'.format(str(j), str(j)) for j in range(i+1)]
        logging.info(' \t '.join(log_list))

        for epoch in range(n_epochs):
            # ---- encoder phase: every head frozen in eval mode ----
            encoder.train()
            for j in ally_ids:
                ally[j].eval()
            advr.eval()

            nsamples = 0
            iloss = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_advr_train_torch = data[1].to(device)

                optimizer_encd.zero_grad()
                X_train_encoded = encoder(X_train_torch)
                loss_ally = [
                    criterionBCEWithLogits(
                        ally[j](X_train_encoded), data[j + 2].to(device))
                    for j in ally_ids
                ]
                loss_advr = criterionBCEWithLogits(
                    advr(X_train_encoded), y_advr_train_torch)

                # Sum the ally losses as tensors so their gradient reaches
                # the encoder. Summing .item() floats (as before) detached
                # them from the graph.
                loss_encd = alpha / num_allies * sum(loss_ally) \
                    - (1 - alpha) * loss_advr
                loss_encd.backward()
                optimizer_encd.step()

                nsamples += 1
                iloss += loss_encd.item()

            epochs_train.append(epoch)
            encd_loss_train.append(iloss/nsamples)

            # ---- head phase: encoder frozen, heads trained on its output ----
            encoder.eval()
            for j in ally_ids:
                ally[j].train()
            advr.train()

            nsamples = 0
            iloss_ally = {j: 0 for j in ally_ids}
            iloss_advr = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_advr_train_torch = data[1].to(device)

                # The encoder is in eval mode and is not stepped here, so one
                # gradient-free forward pass serves every head.
                with torch.no_grad():
                    X_train_encoded = encoder(X_train_torch)

                for j in ally_ids:
                    optimizer_ally[j].zero_grad()
                    loss_ally_j = criterionBCEWithLogits(
                        ally[j](X_train_encoded), data[j + 2].to(device))
                    loss_ally_j.backward()
                    optimizer_ally[j].step()
                    iloss_ally[j] += loss_ally_j.item()

                optimizer_advr.zero_grad()
                loss_advr = criterionBCEWithLogits(
                    advr(X_train_encoded), y_advr_train_torch)
                loss_advr.backward()
                optimizer_advr.step()

                nsamples += 1
                iloss_advr += loss_advr.item()

            for j in ally_ids:
                ally_loss_train[j].append(iloss_ally[j]/nsamples)
            advr_loss_train.append(iloss_advr/nsamples)

            if epoch % eval_every != 0:
                continue

            # ---- validation ----
            encoder.eval()
            for j in ally_ids:
                ally[j].eval()
            advr.eval()

            nsamples = 0
            iloss = 0
            iloss_ally = {j: 0 for j in ally_ids}
            iloss_advr = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_advr_valid_torch = data[1].to(device)

                    X_valid_encoded = encoder(X_valid_torch)
                    valid_loss_ally = [
                        criterionBCEWithLogits(
                            ally[j](X_valid_encoded), data[j + 2].to(device))
                        for j in ally_ids
                    ]
                    valid_loss_advr = criterionBCEWithLogits(
                        advr(X_valid_encoded), y_advr_valid_torch)
                    valid_loss_encd = alpha / num_allies * sum(valid_loss_ally) \
                        - (1 - alpha) * valid_loss_advr

                    nsamples += 1
                    iloss += valid_loss_encd.item()
                    for j in ally_ids:
                        iloss_ally[j] += valid_loss_ally[j].item()
                    iloss_advr += valid_loss_advr.item()

            epochs_valid.append(epoch)
            encd_loss_valid.append(iloss/nsamples)
            for j in ally_ids:
                ally_loss_valid[j].append(iloss_ally[j]/nsamples)
            advr_loss_valid.append(iloss_advr/nsamples)

            log_line = [
                str(epoch),
                '{:.8f}'.format(encd_loss_train[-1]),
                '{:.8f}'.format(encd_loss_valid[-1]),
                '{:.8f}'.format(advr_loss_train[-1]),
                '{:.8f}'.format(advr_loss_valid[-1]),
            ] + [
                '{:.8f} \t {:.8f}'.format(
                    ally_loss_train[j][-1],
                    ally_loss_valid[j][-1]
                ) for j in ally_loss_train
            ]
            logging.info(' \t '.join(log_line))

        config_summary = '{}_n_{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_ally_{}_encd_{:.4f}_advr_{:.4f}'\
            .format(
                marker,
                num_allies,
                device,
                encoding_dim,
                hidden_dim,
                batch_size,
                n_epochs,
                i,
                encd_loss_train[-1],
                advr_loss_valid[-1],
            )

        plt.figure()
        plt.plot(epochs_train, encd_loss_train, 'r', label='encd train')
        plt.plot(epochs_valid, encd_loss_valid, 'r--', label='encd valid')
        # One curve for the mean loss over all allies instead of one per ally.
        plt.plot(
            epochs_train,
            np.mean([ally_loss_train[j] for j in ally_ids], axis=0),
            'b', label='ally_mean_train')
        plt.plot(
            epochs_valid,
            np.mean([ally_loss_valid[j] for j in ally_ids], axis=0),
            'b--', label='ally_mean_valid')
        plt.plot(epochs_train, advr_loss_train, 'g', label='advr_train')
        plt.plot(epochs_valid, advr_loss_valid, 'g--', label='advr_valid')
        plt.legend()
        plt.title("{} on {} training".format(model, expt))

        plot_location = 'plots/{}/{}_training_{}_{}.png'.format(
            expt, model, time_stamp, config_summary)
        sep()
        logging.info('Saving: {}'.format(plot_location))
        plt.savefig(plot_location)
        plt.close()

        checkpoint_location = \
            'checkpoints/{}/{}_training_history_{}_{}.pkl'.format(
                expt, model, time_stamp, config_summary)
        logging.info('Saving: {}'.format(checkpoint_location))
        with open(checkpoint_location, 'wb') as f:
            pkl.dump((
                epochs_train,
                epochs_valid,
                encd_loss_train,
                encd_loss_valid,
                ally_loss_train,
                ally_loss_valid,
                advr_loss_train,
                advr_loss_valid,
            ), f)

        model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format(
                expt, model, time_stamp, config_summary)
        logging.info('Saving: {}'.format(model_ckpt))
        torch.save(encoder, model_ckpt)
def main(
        model,
        time_stamp,
        device,
        ally_classes,
        advr_1_classes,
        advr_2_classes,
        encoding_dim,
        hidden_dim,
        leaky,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        lr_ally,
        lr_advr_1,
        lr_advr_2,
        expt,
        pca_ckpt,
        autoencoder_ckpt,
        encoder_ckpt,
):
    """Train ally and adversary classifiers on PCA-projected features.

    Loads the normalized train/valid split via ``get_data`` and projects the
    features with the object loaded from ``pca_ckpt`` (``joblib``). It then
    trains three independent ``DiscriminatorFCN`` heads one after another:

    * adversary 1: cross-entropy against one-hot targets (column 2)
    * adversary 2: ``BCEWithLogitsLoss`` (column 3)
    * ally: ``BCEWithLogitsLoss`` (column 1)

    Every ``max(1, n_epochs // 10)`` epochs, each head is evaluated on the
    validation split and its train loss, validation loss and accuracy are
    logged. Losses are averaged per batch.

    The history dict ``h`` is pickled to
    ``checkpoints/<expt>/<model>_training_history_<time_stamp>.pkl``.

    ``autoencoder_ckpt`` and ``encoder_ckpt`` are accepted but unused here.
    """
    device = torch_device(device=device)

    X_normalized_train, X_normalized_valid, \
        y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = get_data(expt, test_size)

    pca = joblib.load(pca_ckpt)

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()
    criterionCrossEntropy = nn.CrossEntropyLoss()

    feature = 'pca'
    h = {
        'epoch': {
            'train': [],
            'valid': [],
        },
        feature: {
            'ally_train': [],
            'ally_valid': [],
            'advr_1_train': [],
            'advr_1_valid': [],
            'advr_2_train': [],
            'advr_2_valid': [],
        },
    }

    # Validate roughly ten times per run. The max() avoids the
    # ZeroDivisionError that `epoch % int(n_epochs / 10)` raised whenever
    # n_epochs < 10.
    valid_every = max(1, n_epochs // 10)

    def make_dataset(X, y_ally, y_advr_1, y_advr_2):
        # Column order matters: fit() below reads data[1] (ally),
        # data[2] (adversary 1) and data[3] (adversary 2).
        return utils.TensorDataset(
            torch.Tensor(pca.eval(X)),
            torch.Tensor(y_ally.reshape(-1, ally_classes)),
            torch.Tensor(y_advr_1.reshape(-1, advr_1_classes)),
            torch.Tensor(y_advr_2.reshape(-1, advr_2_classes)),
        )

    dataset_train = make_dataset(
        X_normalized_train, y_ally_train, y_advr_1_train, y_advr_2_train)
    dataset_valid = make_dataset(
        X_normalized_valid, y_ally_valid, y_advr_1_valid, y_advr_2_valid)

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1)
    dataloader_valid = torch.utils.data.DataLoader(
        dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1)

    ally = DiscriminatorFCN(
        encoding_dim, hidden_dim, ally_classes, leaky).to(device)
    advr_1 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_1_classes, leaky).to(device)
    advr_2 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_2_classes, leaky).to(device)

    ally.apply(weights_init)
    advr_1.apply(weights_init)
    advr_2.apply(weights_init)

    for label, net in (('ally', ally), ('advr 1', advr_1), ('advr 2', advr_2)):
        sep('{}:{}'.format(feature, label))
        summary(net, input_size=(1, encoding_dim))

    optimizer_ally = optim(ally.parameters(), lr=lr_ally)
    optimizer_advr_1 = optim(advr_1.parameters(), lr=lr_advr_1)
    optimizer_advr_2 = optim(advr_2.parameters(), lr=lr_advr_2)

    def one_hot_ce_loss(logits, target):
        # Adversary-1 targets are one-hot rows; CrossEntropyLoss wants indices.
        return criterionCrossEntropy(logits, torch.argmax(target, 1))

    def one_hot_correct(logits, target):
        """Return (#correct, #evaluated) for argmax prediction vs one-hot target."""
        _, predicted = torch.max(logits, 1)
        _, actual = torch.max(target, 1)
        return (predicted == actual).sum().item(), actual.size(0)

    def logit_correct(logits, target):
        """Return (#correct, #evaluated) for element-wise binary predictions."""
        # These heads are trained with BCEWithLogitsLoss, so their outputs
        # are raw logits and sigmoid(x) > 0.5 is equivalent to x > 0.
        # Thresholding the logit at 0.5 (the old code) biased predictions
        # toward the negative class.
        predicted = logits > 0
        # numel() counts every label cell, matching how `correct` is counted,
        # so the ratio stays in [0, 1] even with several binary columns.
        return (predicted == target).sum().item(), target.numel()

    def fit(net, optimizer, target_index, loss_fn, count_correct, key,
            track_epochs=False):
        """Train ``net`` on label column ``target_index`` for n_epochs.

        Appends the per-batch mean training loss to h[feature][key + '_train']
        every epoch. Every ``valid_every`` epochs it also appends the
        validation loss to h[feature][key + '_valid'] and logs loss and
        accuracy. When ``track_epochs`` is set, epoch indices are recorded
        in h['epoch'] as well.
        """
        train_key = key + '_train'
        valid_key = key + '_valid'
        for epoch in range(n_epochs):
            net.train()
            nbatches = 0
            running_loss = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_train_torch = data[target_index].to(device)
                optimizer.zero_grad()
                loss = loss_fn(net(X_train_torch), y_train_torch)
                loss.backward()
                optimizer.step()
                nbatches += 1
                running_loss += loss.item()
            if track_epochs and epoch not in h['epoch']['train']:
                h['epoch']['train'].append(epoch)
            h[feature][train_key].append(running_loss / nbatches)

            if epoch % valid_every != 0:
                continue

            net.eval()
            nbatches = 0
            running_loss = 0
            correct = 0
            total = 0
            # Evaluation only: skip building autograd graphs.
            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_valid_torch = data[target_index].to(device)
                    y_valid_hat_torch = net(X_valid_torch)
                    valid_loss = loss_fn(y_valid_hat_torch, y_valid_torch)
                    batch_correct, batch_total = count_correct(
                        y_valid_hat_torch, y_valid_torch)
                    nbatches += 1
                    running_loss += valid_loss.item()
                    correct += batch_correct
                    total += batch_total
            if track_epochs and epoch not in h['epoch']['valid']:
                h['epoch']['valid'].append(epoch)
            h[feature][valid_key].append(running_loss / nbatches)
            logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
                epoch,
                h[feature][train_key][-1],
                h[feature][valid_key][-1],
                correct / total))

    # adversary 1
    sep("adversary 1")
    logging.info('{} \t {} \t {}'.format(
        'Epoch',
        'Advr 1 Train',
        'Advr 1 Valid',
    ))
    fit(advr_1, optimizer_advr_1, 2, one_hot_ce_loss, one_hot_correct,
        'advr_1')

    # adversary 2
    sep("adversary 2")
    logging.info('{} \t {} \t {}'.format(
        'Epoch',
        'Advr 2 Train',
        'Advr 2 Valid',
    ))
    fit(advr_2, optimizer_advr_2, 3, criterionBCEWithLogits, logit_correct,
        'advr_2')

    # ally
    sep("ally")
    logging.info('{} \t {} \t {} \t {}'.format(
        'Epoch',
        'Ally Train',
        'Ally Valid',
        'Accuracy',
    ))
    fit(ally, optimizer_ally, 1, criterionBCEWithLogits, logit_correct,
        'ally', track_epochs=True)

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    # Context manager so the file handle is flushed and closed.
    with open(checkpoint_location, 'wb') as f:
        pkl.dump(h, f)