def comparison_argparse(debug=True):
    ap = argparse.ArgumentParser()
    ap.add_argument("--device", required=True)
    ap.add_argument("--n-ally", type=int, nargs='+', required=False)
    ap.add_argument("--n-advr", type=int, nargs='+', required=False)
    ap.add_argument("--dim", type=int, required=True)
    ap.add_argument("--hidden-dim", type=int, required=True)
    ap.add_argument("--leaky", type=int, required=True)
    ap.add_argument("--epsilon", type=float, required=False)
    ap.add_argument("--test-size", type=float, required=True)
    ap.add_argument("--batch-size", type=int, required=True)
    ap.add_argument("--n-epochs", type=int, required=True)
    ap.add_argument("--shuffle", type=int, required=True)
    ap.add_argument("--lr", type=float, required=False)
    ap.add_argument("--lr-ally", type=float, nargs='+', required=False)
    ap.add_argument("--lr-advr", type=float, nargs='+', required=False)
    ap.add_argument("--expt", required=True)
    ap.add_argument("--pca-ckpt", required=False)
    ap.add_argument("--autoencoder-ckpt", required=False)
    ap.add_argument("--encoder-ckpt", required=False)
    args = vars(ap.parse_args())
    if debug:
        sep()
        logging.info(json.dumps(args, indent=2))
    return args
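# The argparse helpers and training mains in this file call small logging
# utilities -- sep(), time_stp(), log_time(), logger() -- that are defined
# elsewhere in the repo. The implementations below are a minimal sketch
# inferred from how they are called here (separator lines, a printable and a
# filename-safe timestamp, per-experiment log files); the bodies and the
# 'logs/<expt>/' path are assumptions, not the repo's actual code.
import logging
import time


def sep(text=''):
    # Log a horizontal separator, optionally with a label.
    logging.info('-' * 80)
    if text:
        logging.info(text)
        logging.info('-' * 80)


def time_stp():
    # Return (printable timestamp, filename-safe timestamp).
    now = time.localtime()
    printable = time.strftime('%d-%m-%Y %H:%M:%S', now)
    file_safe = time.strftime('%m_%d_%Y_%H_%M_%S', now)
    return printable, file_safe


def log_time(label, printable_time):
    # Record when a phase starts or ends.
    logging.info('{}: {}'.format(label, printable_time))


def logger(expt, model, file_time, marker):
    # Route the root logger to stdout and a per-experiment log file
    # (assumed layout: logs/<expt>/<model>_<marker>_<timestamp>.log).
    logging.basicConfig(
        level=logging.INFO,
        format='%(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler('logs/{}/{}_{}_{}.log'.format(
                expt, model, marker, file_time)),
        ])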
def main(
        model,
        time_stamp,
        ally_classes,
        advr_1_classes,
        advr_2_classes,
        test_size,
        expl_var,
        expt,
):
    X, y_ally, y_advr_1, y_advr_2 = load_processed_data(
        expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl')
    log_shapes([X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded')

    X_train, X_valid = train_test_split(
        X, test_size=test_size,
        stratify=pd.DataFrame(
            np.concatenate((
                y_ally.reshape(-1, ally_classes),
                y_advr_1.reshape(-1, advr_1_classes),
                y_advr_2.reshape(-1, advr_2_classes),
            ), axis=1)))

    log_shapes([X_train, X_valid], locals(),
               'Data size after train test split')

    scaler = StandardScaler()
    X_train_normalized = scaler.fit_transform(X_train)
    X_valid_normalized = scaler.transform(X_valid)

    log_shapes([X_train_normalized, X_valid_normalized], locals())

    pca = PCABasic(expl_var)
    X_train_pca = pca.train(X_train_normalized)
    X_valid_pca = pca.eval(X_valid_normalized)

    sep()
    logging.info('\nExplained Variance: {}\nNum Components: {}'.format(
        str(expl_var), pca.num_components))

    config_summary = 'dim_{}'.format(pca.num_components)
    model_ckpt = 'checkpoints/{}/{}_sklearn_model_{}_{}.pkl'.format(
        expt, model, time_stamp, config_summary)

    sep()
    logging.info('Saving: {}'.format(model_ckpt))
    joblib.dump(pca, model_ckpt)
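# PCABasic is defined elsewhere in the repo. Inferring from how it is used
# above and in the comparison scripts below (PCABasic(expl_var), .train(),
# .eval(), .num_components), it most likely wraps sklearn's PCA and keeps
# just enough components to reach the requested explained variance. The
# class below is a sketch under that assumption, not the repo's actual
# implementation.
from sklearn.decomposition import PCA


class PCABasic:
    def __init__(self, expl_var):
        # expl_var in (0, 1]: target fraction of variance to retain.
        self.expl_var = expl_var
        self.pca = None
        self.num_components = None

    def train(self, X):
        # With a float n_components, sklearn selects the number of
        # components needed to reach that explained-variance ratio.
        self.pca = PCA(n_components=self.expl_var, svd_solver='full')
        X_reduced = self.pca.fit_transform(X)
        self.num_components = self.pca.n_components_
        return X_reduced

    def eval(self, X):
        # Project new data onto the already-fitted components.
        return self.pca.transform(X)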
def pca_argparse(debug=True):
    ap = argparse.ArgumentParser()
    ap.add_argument("--n-ally", type=int, nargs='+', required=False)
    ap.add_argument("--n-advr", type=int, nargs='+', required=False)
    ap.add_argument("--test-size", type=float, required=False)
    ap.add_argument("--expl-var", type=float, required=True)
    ap.add_argument("--expt", required=True)
    args = vars(ap.parse_args())
    if debug:
        sep()
        logging.info(json.dumps(args, indent=2))
    return args
def eigan_argparse(debug=True):
    ap = argparse.ArgumentParser()
    ap.add_argument("--device", required=True)
    ap.add_argument("--n-gpu", type=int, required=False)
    ap.add_argument("--n-nodes", type=int, required=False)
    ap.add_argument("--n-ally", type=int, nargs='+', required=False)
    ap.add_argument("--n-advr", type=int, nargs='+', required=False)
    ap.add_argument("--n-channels", type=int, required=False)
    ap.add_argument("--n-filters", type=int, required=False)
    ap.add_argument("--dim", type=int, required=False)
    ap.add_argument("--hidden-dim", type=int, required=False)
    ap.add_argument("--leaky", type=int, required=False)
    ap.add_argument("--activation", required=False)
    ap.add_argument("--test-size", type=float, required=False)
    ap.add_argument("--batch-size", type=int, required=True)
    ap.add_argument("--n-epochs", type=int, required=True)
    ap.add_argument("--shuffle", type=int, required=False)
    ap.add_argument("--init-w", type=int, required=False)
    ap.add_argument("--lr-encd", type=float, required=True)
    ap.add_argument("--lr-ally", type=float, nargs='+', required=True)
    ap.add_argument("--lr-advr", type=float, nargs='+', required=False)
    ap.add_argument("--alpha", type=float, required=False)
    ap.add_argument("--g-reps", type=int, required=False)
    ap.add_argument("--d-reps", type=int, required=False)
    ap.add_argument("--num-allies", type=int, required=False)
    ap.add_argument("--num-adversaries", type=int, required=False)
    ap.add_argument("--expt", required=True)
    ap.add_argument("--encd-ckpt", required=False)
    ap.add_argument("--ally-ckpts", nargs="+", required=False)
    ap.add_argument("--advr-ckpts", nargs="+", required=False)
    args = vars(ap.parse_args())
    if debug:
        sep()
        logging.info(json.dumps(args, indent=2))
    return args
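# All of the *_argparse helpers above follow the same pattern: flags are
# collected with argparse, vars() turns the Namespace into a plain dict, and
# nargs='+' flags (--n-ally, --n-advr, --lr-ally, --lr-advr, ...) come back
# as lists even when a single value is passed. A self-contained
# illustration of that behaviour; the flag values here are made up for the
# example and are not taken from any experiment in this repo.
import argparse

demo_ap = argparse.ArgumentParser()
demo_ap.add_argument("--device", required=True)
demo_ap.add_argument("--n-ally", type=int, nargs='+', required=False)
demo_ap.add_argument("--n-advr", type=int, nargs='+', required=False)
demo_ap.add_argument("--lr-advr", type=float, nargs='+', required=False)

demo_args = vars(demo_ap.parse_args(
    ["--device", "cuda",
     "--n-ally", "1",
     "--n-advr", "1", "3",
     "--lr-advr", "0.001", "0.0001"]))
# demo_args['n_ally'] == [1], demo_args['n_advr'] == [1, 3],
# demo_args['lr_advr'] == [0.001, 0.0001] -- note the dashes in the flag
# names become underscores in the dict keys.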
def main(
        model,
        time_stamp,
        expl_var,
        expt,
):
    X_train, X_valid, \
        y_train, y_valid = get_data(expt)

    pca = PCABasic(expl_var)
    pca.train(X_train.reshape(cfg.num_trains[expt], -1))

    sep()
    logging.info('\nExplained Variance: {}\nNum Components: {}'.format(
        str(expl_var), pca.num_components))

    model_ckpt = 'ckpts/{}/models/{}_{}.pkl'.format(expt, model, marker)
    sep()
    logging.info('Saving: {}'.format(model_ckpt))
    joblib.dump(pca, model_ckpt)
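# Both PCA mains above persist the fitted model with joblib so that the
# downstream comparison scripts can reload it and project data. A
# self-contained sketch of that round trip, using a plain sklearn PCA as a
# stand-in for PCABasic and a throwaway path/arrays (everything here is
# placeholder data, not an artifact of this repo):
import joblib
import numpy as np
from sklearn.decomposition import PCA

X_fit = np.random.rand(64, 20)
pca_demo = PCA(n_components=0.9, svd_solver='full').fit(X_fit)
joblib.dump(pca_demo, '/tmp/ind_pca_example.pkl')     # placeholder checkpoint

pca_loaded = joblib.load('/tmp/ind_pca_example.pkl')
X_projected = pca_loaded.transform(np.random.rand(8, 20))
# X_projected has shape (8, pca_loaded.n_components_).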
def main( model, time_stamp, device, encoding_dim, hidden_dim, leaky, test_size, batch_size, n_epochs, shuffle, lr, expt, pca_ckpt, autoencoder_ckpt, encoder_ckpt, ): device = torch_device(device=device) X, targets = load_processed_data( expt, 'processed_data_X_targets.pkl') log_shapes( [X] + [targets[i] for i in targets], locals(), 'Dataset loaded' ) targets = {i: elem.reshape(-1, 1) for i, elem in targets.items()} X_train, X_valid, \ y_adt_train, y_adt_valid = train_test_split( X, targets['admission_type'], test_size=test_size, stratify=pd.DataFrame(np.concatenate( ( targets['admission_type'], ), axis=1) ) ) log_shapes( [ X_train, X_valid, y_adt_train, y_adt_valid, ], locals(), 'Data size after train test split' ) y_train = y_adt_train y_valid = y_adt_valid scaler = StandardScaler() X_normalized_train = scaler.fit_transform(X_train) X_normalized_valid = scaler.transform(X_valid) log_shapes([X_normalized_train, X_normalized_valid], locals()) ckpts = { # 123A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_17_41_09.pkl # 0: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_16_13_39_A_n_1_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_0_encd_0.0471_advr_0.5991.pkl', # 28B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_00_23_29.pkl # 1: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_22_51_11_B_n_2_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_1_encd_0.0475_advr_0.5992.pkl', # 123A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_17_41_09.pkl # 2: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_16_14_37_A_n_3_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_2_encd_0.0464_advr_0.5991.pkl', # 224A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_20_09_39.pkl # 3: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_18_08_09_A_n_4_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_3_encd_0.0469_advr_0.5991.pkl', # 24A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_00_21_50.pkl # 4: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_23_12_05_A_n_5_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_4_encd_0.0468_advr_0.5994.pkl', # 67A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_05_30_09.pkl # 5: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_00_15_28_A_n_6_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_5_encd_0.0462_advr_0.5991.pkl', # 67A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_05_30_09.pkl # 6: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_00_42_31_A_n_7_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_6_encd_0.0453_advr_0.5992.pkl', # 28B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_00_23_29.pkl # 7: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_22_54_21_B_n_8_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_7_encd_0.0477_advr_0.5992.pkl', # 9B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_18_09_03.pkl # 8: 'checkpoints/mimic/n_eigan_torch_model_02_05_2020_00_48_34_B_n_9_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_8_encd_0.0473_advr_0.6000.pkl', # nA: heckpoints/mimic/n_ind_gan_training_history_02_04_2020_20_13_29.pkl # 9: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_18_51_01_A_n_10_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_9_encd_0.0420_advr_0.5992.pkl', } h = {} for idx, ckpt in ckpts.items(): encoder = torch.load(ckpt, map_location=device) encoder.eval() optim = torch.optim.Adam criterionBCEWithLogits = nn.BCEWithLogitsLoss() h[idx] = { 'epoch_train': [], 
'epoch_valid': [], 'advr_train': [], 'advr_valid': [], } dataset_train = utils.TensorDataset( torch.Tensor(X_normalized_train), torch.Tensor(y_train), ) dataset_valid = utils.TensorDataset( torch.Tensor(X_normalized_valid), torch.Tensor(y_valid), ) def transform(input_arg): return encoder(input_arg) dataloader_train = torch.utils.data.DataLoader( dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1 ) dataloader_valid = torch.utils.data.DataLoader( dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1 ) clf = DiscriminatorFCN( encoding_dim, hidden_dim, 1, leaky).to(device) clf.apply(weights_init) sep('{} {}'.format(idx+1, 'ally')) summary(clf, input_size=(1, encoding_dim)) optimizer = optim(clf.parameters(), lr=lr) # adversary 1 sep("adversary with {} ally encoder".format(idx+1)) logging.info('{} \t {} \t {}'.format( 'Epoch', 'Advr Train', 'Advr Valid', )) for epoch in range(n_epochs): clf.train() nsamples = 0 iloss_advr = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = transform(data[0].to(device)) y_advr_train_torch = data[1].to(device) optimizer.zero_grad() y_advr_train_hat_torch = clf(X_train_torch) loss_advr = criterionBCEWithLogits( y_advr_train_hat_torch, y_advr_train_torch) loss_advr.backward() optimizer.step() nsamples += 1 iloss_advr += loss_advr.item() h[idx]['advr_train'].append(iloss_advr/nsamples) h[idx]['epoch_train'].append(epoch) if epoch % int(n_epochs/10) != 0: continue clf.eval() nsamples = 0 iloss_advr = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = transform(data[0].to(device)) y_advr_valid_torch = data[1].to(device) y_advr_valid_hat_torch = clf(X_valid_torch) valid_loss_advr = criterionBCEWithLogits( y_advr_valid_hat_torch, y_advr_valid_torch,) predicted = y_advr_valid_hat_torch > 0.5 nsamples += 1 iloss_advr += valid_loss_advr.item() total += y_advr_valid_torch.size(0) correct += (predicted == y_advr_valid_torch).sum().item() h[idx]['advr_valid'].append(iloss_advr/nsamples) h[idx]['epoch_valid'].append(epoch) logging.info( '{} \t {:.8f} \t {:.8f} \t {:.8f}'. format( epoch, h[idx]['advr_train'][-1], h[idx]['advr_valid'][-1], correct/total )) checkpoint_location = \ 'checkpoints/{}/{}_training_history_{}.pkl'.format( expt, model, time_stamp) sep() logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump(h, open(checkpoint_location, 'wb'))
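# torch_device() and weights_init() are imported from the repo's common
# module. From their usage above, torch_device resolves a device string to a
# torch.device and weights_init is applied with Module.apply() to
# (re-)initialise layers. The bodies below are plausible sketches rather
# than the repo's code; the init follows the common DCGAN-style recipe
# adapted to Linear layers.
import torch
import torch.nn as nn


def torch_device(device='cpu'):
    # Fall back to CPU when CUDA was requested but is unavailable.
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    return torch.device(device)


def weights_init(m):
    # Called once per sub-module via model.apply(weights_init).
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)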
def main( model, device, encoding_dim, batch_size, n_epochs, shuffle, lr, expt, ): device = torch_device(device=device) X_train, X_valid, \ y_train, y_valid = get_data(expt) dataset_train = utils.TensorDataset( torch.Tensor(X_train.reshape(cfg.num_trains[expt], -1))) dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=2) dataset_valid = utils.TensorDataset( torch.Tensor(X_valid.reshape(cfg.num_tests[expt], -1))) dataloader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=2) auto_encoder = AutoEncoderBasic(input_size=cfg.input_sizes[expt], encoding_dim=encoding_dim).to(device) criterion = torch.nn.MSELoss() adam_optim = torch.optim.Adam optimizer = adam_optim(auto_encoder.parameters(), lr=lr) summary(auto_encoder, input_size=(1, cfg.input_sizes[expt])) h_epoch = [] h_valid = [] h_train = [] auto_encoder.train() sep() logging.info("epoch \t train \t valid") best = math.inf config_summary = 'device_{}_dim_{}_batch_{}_epochs_{}_lr_{}'.format( device, encoding_dim, batch_size, n_epochs, lr, ) for epoch in range(n_epochs): nsamples = 0 iloss = 0 for data in dataloader_train: optimizer.zero_grad() X_torch = data[0].to(device) X_torch_hat = auto_encoder(X_torch) loss = criterion(X_torch_hat, X_torch) loss.backward() optimizer.step() nsamples += 1 iloss += loss.item() if epoch % int(n_epochs / 10) != 0: continue h_epoch.append(epoch) h_train.append(iloss / nsamples) nsamples = 0 iloss = 0 for data in dataloader_valid: X_torch = data[0].to(device) X_torch_hat = auto_encoder(X_torch) loss = criterion(X_torch_hat, X_torch) nsamples += 1 iloss += loss.item() h_valid.append(iloss / nsamples) if h_valid[-1] < best: best = h_valid[-1] model_ckpt = 'ckpts/{}/models/{}_{}_{}.best'.format( expt, model, config_summary, marker) logging.info('Saving: {}'.format(model_ckpt)) torch.save(auto_encoder.state_dict(), model_ckpt) logging.info('{} \t {:.8f} \t {:.8f}'.format( h_epoch[-1], h_train[-1], h_valid[-1], )) fig = plt.figure(figsize=(5, 4)) ax = fig.add_subplot(111) ax.plot(h_epoch, h_train, 'r.:') ax.plot(h_epoch, h_valid, 'rs-.') ax.set_xlabel('epochs') ax.set_ylabel('loss (MSEE)') plt.legend(['train loss', 'valid loss']) plot_location = 'ckpts/{}/plots/{}_{}_{}.png'.format( expt, model, config_summary, marker) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location) checkpoint_location = 'ckpts/{}/history/{}_{}_{}.pkl'.format( expt, model, config_summary, marker) logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump((h_epoch, h_train, h_valid), open(checkpoint_location, 'wb')) model_ckpt = 'ckpts/{}/models/{}_{}_{}.stop'.format( expt, model, config_summary, marker) logging.info('Saving: {}'.format(model_ckpt)) torch.save(auto_encoder.state_dict(), model_ckpt)
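# AutoEncoderBasic is defined elsewhere in the repo. Given its constructor
# arguments (input_size, encoding_dim) and the MSE reconstruction loss used
# in the training loop above, it is presumably a small symmetric MLP
# autoencoder. A sketch under that assumption; the layer widths and
# activations here are guesses.
import torch.nn as nn


class AutoEncoderBasic(nn.Module):
    def __init__(self, input_size, encoding_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, encoding_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(encoding_dim, input_size),
            nn.Sigmoid(),
        )

    def encode(self, x):
        # Encoding used downstream as the privacy-preserving representation.
        return self.encoder(x)

    def forward(self, x):
        # Reconstruction x -> z -> x_hat, trained with MSE against x.
        return self.decoder(self.encoder(x))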
def main(model, time_stamp, device, ngpu, num_nodes, ally_classes, advr_1_classes, advr_2_classes, encoding_dim, hidden_dim, leaky, activation, test_size, batch_size, n_epochs, shuffle, init_weight, lr_encd, lr_ally, lr_advr_1, lr_advr_2, alpha, g_reps, d_reps, expt, marker): device = torch_device(device=device) X, y_ally, y_advr_1, y_advr_2 = load_processed_data( expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl') log_shapes([X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded') X_train, X_valid, \ y_ally_train, y_ally_valid, \ y_advr_1_train, y_advr_1_valid, \ y_advr_2_train, y_advr_2_valid = train_test_split( X, y_ally, y_advr_1, y_advr_2, test_size=test_size, stratify=pd.DataFrame(np.concatenate( ( y_ally.reshape(-1, ally_classes), y_advr_1.reshape(-1, advr_1_classes), y_advr_2.reshape(-1, advr_2_classes), ), axis=1) ) ) log_shapes([ X_train, X_valid, y_ally_train, y_ally_valid, y_advr_1_train, y_advr_1_valid, y_advr_2_train, y_advr_2_valid, ], locals(), 'Data size after train test split') num_features = X_train.shape[1] split_pts = num_features // num_nodes X_train, X_valid = np.split(X_train, num_nodes, axis=1), np.split(X_valid, num_nodes, axis=1) log_shapes(X_train + X_valid, locals(), 'Data size after splitting data among nodes') scaler = StandardScaler() X_normalized_train, X_normalized_valid = [], [] for train, valid in zip(X_train, X_valid): X_normalized_train.append(scaler.fit_transform(train)) X_normalized_valid.append(scaler.transform(valid)) log_shapes(X_normalized_train + X_normalized_valid, locals()) encoders = [] allies = [] adversaries_1 = [] adversaries_2 = [] for train in X_normalized_train: encoders.append( GeneratorFCN(train.shape[1], hidden_dim, encoding_dim // num_nodes, leaky, activation).to(device)) allies.append( DiscriminatorFCN(encoding_dim // num_nodes, hidden_dim, ally_classes, leaky).to(device)) adversaries_1.append( DiscriminatorFCN(encoding_dim // num_nodes, hidden_dim, advr_1_classes, leaky).to(device)) adversaries_2.append( DiscriminatorFCN(encoding_dim // num_nodes, hidden_dim, advr_2_classes, leaky).to(device)) sep('encoders') for k in range(num_nodes): summary(encoders[k], input_size=(1, X_normalized_train[k].shape[1])) sep('ally') for k in range(num_nodes): summary(allies[k], input_size=(1, encoding_dim // num_nodes)) sep('advr_1') for k in range(num_nodes): summary(adversaries_1[k], input_size=(1, encoding_dim // num_nodes)) sep('advr_2') for k in range(num_nodes): summary(adversaries_2[k], input_size=(1, encoding_dim // num_nodes)) optim = torch.optim.Adam criterionBCEWithLogits = nn.BCEWithLogitsLoss() optimizers_encd = [] for encoder in encoders: optimizers_encd.append(optim(encoder.parameters(), lr=lr_encd)) optimizers_ally = [] for ally in allies: optimizers_ally.append(optim(ally.parameters(), lr=lr_ally)) optimizers_advr_1 = [] for advr_1 in adversaries_1: optimizers_advr_1.append(optim(advr_1.parameters(), lr=lr_advr_1)) optimizers_advr_2 = [] for advr_2 in adversaries_2: optimizers_advr_2.append(optim(advr_2.parameters(), lr=lr_advr_2)) for k in range(num_nodes): sep('Node {}'.format(k)) dataset_train = utils.TensorDataset( torch.Tensor(X_normalized_train[k]), torch.Tensor(y_ally_train.reshape(-1, ally_classes)), torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)), torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)), ) dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1) dataset_valid = utils.TensorDataset( torch.Tensor(X_normalized_valid[k]), 
torch.Tensor(y_ally_valid.reshape(-1, ally_classes)), torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)), torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)), ) dataloader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1) epochs_train = [] epochs_valid = [] encd_loss_train = [] encd_loss_valid = [] ally_loss_train = [] ally_loss_valid = [] advr_1_loss_train = [] advr_1_loss_valid = [] advr_2_loss_train = [] advr_2_loss_valid = [] logging.info( '{} \t {} \t {} \t {} \t {} \t {} \t {} \t {} \t {}'.format( 'Epoch', 'Encd Train', 'Encd Valid', 'Ally Train', 'Ally Valid', 'Advr 1 Train', 'Advr 1 Valid', 'Advr 2 Train', 'Advr 2 Valid', )) encoder = encoders[k] ally = allies[k] advr_1 = adversaries_1[k] advr_2 = adversaries_2[k] optimizer_encd = optimizers_encd[k] optimizer_ally = optimizers_ally[k] optimizer_advr_1 = optimizers_advr_1[k] optimizer_advr_2 = optimizers_advr_2[k] for epoch in range(n_epochs): encoder.train() ally.eval() advr_1.eval() advr_2.eval() for __ in range(g_reps): nsamples = 0 iloss = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_ally_train_torch = data[1].to(device) y_advr_1_train_torch = data[2].to(device) y_advr_2_train_torch = data[3].to(device) optimizer_encd.zero_grad() # Forward pass X_train_encoded = encoder(X_train_torch) y_ally_train_hat_torch = ally(X_train_encoded) y_advr_1_train_hat_torch = advr_1(X_train_encoded) y_advr_2_train_hat_torch = advr_2(X_train_encoded) # Compute Loss loss_ally = criterionBCEWithLogits(y_ally_train_hat_torch, y_ally_train_torch) loss_advr_1 = criterionBCEWithLogits( y_advr_1_train_hat_torch, y_advr_1_train_torch) loss_advr_2 = criterionBCEWithLogits( y_advr_2_train_hat_torch, y_advr_2_train_torch) loss_encd = alpha * loss_ally - (1 - alpha) / 2 * ( loss_advr_1 + loss_advr_2) # Backward pass loss_encd.backward() optimizer_encd.step() nsamples += 1 iloss += loss_encd.item() epochs_train.append(epoch) encd_loss_train.append(iloss / nsamples) encoder.eval() ally.train() advr_1.train() advr_2.train() for __ in range(d_reps): nsamples = 0 iloss_ally = 0 iloss_advr_1 = 0 iloss_advr_2 = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_ally_train_torch = data[1].to(device) y_advr_1_train_torch = data[2].to(device) y_advr_2_train_torch = data[3].to(device) optimizer_ally.zero_grad() X_train_encoded = encoder(X_train_torch) y_ally_train_hat_torch = ally(X_train_encoded) loss_ally = criterionBCEWithLogits(y_ally_train_hat_torch, y_ally_train_torch) loss_ally.backward() optimizer_ally.step() optimizer_advr_1.zero_grad() X_train_encoded = encoder(X_train_torch) y_advr_1_train_hat_torch = advr_1(X_train_encoded) loss_advr_1 = criterionBCEWithLogits( y_advr_1_train_hat_torch, y_advr_1_train_torch) loss_advr_1.backward() optimizer_advr_1.step() optimizer_advr_2.zero_grad() X_train_encoded = encoder(X_train_torch) y_advr_2_train_hat_torch = advr_2(X_train_encoded) loss_advr_2 = criterionBCEWithLogits( y_advr_2_train_hat_torch, y_advr_2_train_torch) loss_advr_2.backward() optimizer_advr_2.step() nsamples += 1 iloss_ally += loss_ally.item() iloss_advr_1 += loss_advr_1.item() iloss_advr_2 += loss_advr_2.item() ally_loss_train.append(iloss_ally / nsamples) advr_1_loss_train.append(iloss_advr_1 / nsamples) advr_2_loss_train.append(iloss_advr_2 / nsamples) if epoch % int(n_epochs / 10) != 0: continue encoder.eval() ally.eval() advr_1.eval() advr_2.eval() nsamples = 0 iloss = 0 iloss_ally = 0 iloss_advr_1 = 0 
iloss_advr_2 = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = data[0].to(device) y_ally_valid_torch = data[1].to(device) y_advr_1_valid_torch = data[2].to(device) y_advr_2_valid_torch = data[3].to(device) X_valid_encoded = encoder(X_valid_torch) y_ally_valid_hat_torch = ally(X_valid_encoded) y_advr_1_valid_hat_torch = advr_1(X_valid_encoded) y_advr_2_valid_hat_torch = advr_2(X_valid_encoded) valid_loss_ally = criterionBCEWithLogits( y_ally_valid_hat_torch, y_ally_valid_torch) valid_loss_advr_1 = criterionBCEWithLogits( y_advr_1_valid_hat_torch, y_advr_1_valid_torch) valid_loss_advr_2 = criterionBCEWithLogits( y_advr_2_valid_hat_torch, y_advr_2_valid_torch) valid_loss_encd = alpha*valid_loss_ally - (1-alpha)/2*(valid_loss_advr_1 + \ valid_loss_advr_2) nsamples += 1 iloss += valid_loss_encd.item() iloss_ally += valid_loss_ally.item() iloss_advr_1 += valid_loss_advr_1.item() iloss_advr_2 += valid_loss_advr_2.item() epochs_valid.append(epoch) encd_loss_valid.append(iloss / nsamples) ally_loss_valid.append(iloss_ally / nsamples) advr_1_loss_valid.append(iloss_advr_1 / nsamples) advr_2_loss_valid.append(iloss_advr_2 / nsamples) logging.info( '{} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f}' .format( epoch, encd_loss_train[-1], encd_loss_valid[-1], ally_loss_train[-1], ally_loss_valid[-1], advr_1_loss_train[-1], advr_1_loss_valid[-1], advr_2_loss_train[-1], advr_2_loss_valid[-1], )) config_summary = '{}_node_{}_{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_lrencd_{}_lrally_{}_tr_{:.4f}_val_{:.4f}'\ .format( marker, num_nodes, k, device, encoding_dim, hidden_dim, batch_size, n_epochs, lr_encd, lr_ally, encd_loss_train[-1], advr_1_loss_valid[-1], ) plt.plot(epochs_train, encd_loss_train, 'r') plt.plot(epochs_valid, encd_loss_valid, 'r--') plt.plot(epochs_train, ally_loss_train, 'b') plt.plot(epochs_valid, ally_loss_valid, 'b--') plt.plot(epochs_train, advr_1_loss_train, 'g') plt.plot(epochs_valid, advr_1_loss_valid, 'g--') plt.plot(epochs_train, advr_2_loss_train, 'y') plt.plot(epochs_valid, advr_2_loss_valid, 'y--') plt.legend([ 'encoder train', 'encoder valid', 'ally train', 'ally valid', 'advr 1 train', 'advr 1 valid', 'advr 2 train', 'advr 2 valid', ]) plt.title("{}:{}/{} on {} training".format(model, k, num_nodes, expt)) plot_location = 'plots/{}/{}_{}_{}_training_{}_{}.png'.format( expt, model, num_nodes, k, time_stamp, config_summary) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location) checkpoint_location = \ 'checkpoints/{}/{}_{}_{}_training_history_{}_{}.pkl'.format( expt, model, num_nodes, k, time_stamp, config_summary) logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump(( epochs_train, epochs_valid, encd_loss_train, encd_loss_valid, ally_loss_train, ally_loss_valid, advr_1_loss_train, advr_1_loss_valid, advr_2_loss_train, advr_2_loss_valid, ), open(checkpoint_location, 'wb')) model_ckpt = 'checkpoints/{}/{}_{}_{}_torch_model_{}_{}.pkl'.format( expt, model, num_nodes, k, time_stamp, config_summary) logging.info('Saving: {}'.format(model_ckpt)) torch.save(encoder, model_ckpt)
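# The per-node encoder update above optimises a weighted ally-vs-adversary
# objective:
#     L_enc = alpha * L_ally - (1 - alpha) / 2 * (L_advr_1 + L_advr_2)
# i.e. the encoding should remain predictive for the ally task while hurting
# both adversary tasks, with the discriminators trained on their own BCE
# losses in alternating g_reps/d_reps steps. A self-contained numeric
# illustration of just the loss combination; random tensors stand in for the
# model outputs and alpha = 0.5 is an example value.
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
alpha = 0.5

logits_ally = torch.randn(16, 1)     # stand-in for ally(encoder(x))
logits_advr_1 = torch.randn(16, 1)   # stand-in for advr_1(encoder(x))
logits_advr_2 = torch.randn(16, 1)   # stand-in for advr_2(encoder(x))
y_ally = torch.randint(0, 2, (16, 1)).float()
y_advr_1 = torch.randint(0, 2, (16, 1)).float()
y_advr_2 = torch.randint(0, 2, (16, 1)).float()

loss_ally = bce(logits_ally, y_ally)
loss_advr_1 = bce(logits_advr_1, y_advr_1)
loss_advr_2 = bce(logits_advr_2, y_advr_2)
loss_encd = alpha * loss_ally - (1 - alpha) / 2 * (loss_advr_1 + loss_advr_2)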
def main(expt, model): pca_1 = pkl.load( open( 'checkpoints/mimic/ind_pca_training_history_01_20_2020_23_31_01.pkl', 'rb')) pca_2 = pkl.load( open( 'checkpoints/mimic/ind_pca_training_history_01_21_2020_00_19_41.pkl', 'rb')) auto_1 = pkl.load( open( 'checkpoints/mimic/ind_autoencoder_training_history_01_24_2020_13_50_25.pkl', 'rb')) dp_1 = pkl.load( open( 'checkpoints/mimic/ind_dp_training_history_01_24_2020_07_31_44.pkl', 'rb')) gan_1 = pkl.load( open( 'checkpoints/mimic/ind_gan_training_history_01_24_2020_13_16_16.pkl', 'rb')) gan_2 = pkl.load( open( 'checkpoints/mimic/ind_gan_training_history_01_25_2020_02_04_34.pkl', 'rb')) # gan_1 = pkl.load(open('checkpoints/mimic/ind_gan_training_history_01_27_2020_00_57_38.pkl', 'rb')) # gan_2 = gan_1 # print(pca_1.keys(), pca_2.keys(), auto_1.keys(), auto_2.keys(), dp_1.keys(), gan_1.keys()) # return plt.figure() fig = plt.figure(figsize=(15, 3)) ax3 = fig.add_subplot(131) ax1 = fig.add_subplot(132) ax2 = fig.add_subplot(133) t3, t1, t2 = '(a)', '(b)', '(c)' ax3.plot(pca_1['epoch']['valid'], gan_1['encoder']['ally_valid'], 'r') ax3.plot(pca_1['epoch']['valid'], pca_1['pca']['ally_valid'], 'g') ax3.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['ally_valid'], 'b') ax3.plot(pca_1['epoch']['valid'], dp_1['dp']['ally_valid'], 'y') ax3.legend([ 'EIGAN ally', 'Autoencoder ally', 'PCA ally', 'DP ally', ]) ax3.set_title(t3, y=-0.3) ax3.set_xlabel('epochs') ax3.set_ylabel('log loss') ax3.grid() ax3.text(320, 0.618, 'Lower is better', fontsize=12, color='r') ax3.set_ylim(bottom=0.58) ax1.plot(pca_1['epoch']['valid'], gan_2['encoder']['advr_1_valid'], 'r--') ax1.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_1_valid'], 'b--') ax1.plot(pca_1['epoch']['valid'], pca_1['pca']['advr_1_valid'], 'g--') ax1.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_1_valid'], 'y--') ax1.legend([ 'EIGAN adversary 1', 'Autoencoder adversary 1', 'PCA adversary 1', 'DP adversary 1', ]) ax1.set_title(t1, y=-0.3) ax1.set_xlabel('epochs') ax1.set_ylabel('log loss') ax1.grid() ax1.text(320, 0.67, 'Higher is better', fontsize=12, color='r') ax1.set_ylim(bottom=0.66) ax2.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_2_valid'], 'r--') ax2.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_2_valid'], 'b--') ax2.plot(pca_1['epoch']['valid'], pca_2['pca']['advr_2_valid'], 'g--') ax2.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_2_valid'], 'y--') ax2.legend([ 'EIGAN adversary 2', 'Autoencoder adversary 2', 'PCA adversary 2', 'DP adversary 2', ]) ax2.set_title(t2, y=-0.3) ax2.set_xlabel('epochs') ax2.set_ylabel('log loss') ax2.grid() ax2.text(320, 0.56, 'Higher is better', fontsize=12, color='r') ax2.set_ylim(bottom=0.54, top=0.64) fig.subplots_adjust(wspace=0.3) plot_location = 'plots/{}/{}_{}.png'.format(expt, 'all', model) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location, bbox_inches='tight')
def main( model, time_stamp, device, ngpu, ally_classes, advr_1_classes, advr_2_classes, encoding_dim, hidden_dim, leaky, activation, test_size, batch_size, n_epochs, shuffle, init_weight, lr_encd, lr_ally, lr_advr_1, lr_advr_2, alpha, g_reps, d_reps, expt, marker ): device = torch_device(device=device) X, y_ally, y_advr_1, y_advr_2 = load_processed_data( expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl') log_shapes( [X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded' ) X_train, X_valid, \ y_ally_train, y_ally_valid, \ y_advr_1_train, y_advr_1_valid, \ y_advr_2_train, y_advr_2_valid = train_test_split( X, y_ally, y_advr_1, y_advr_2, test_size=test_size, stratify=pd.DataFrame(np.concatenate( ( y_ally.reshape(-1, ally_classes), y_advr_1.reshape(-1, advr_1_classes), y_advr_2.reshape(-1, advr_2_classes), ), axis=1) ) ) log_shapes( [ X_train, X_valid, y_ally_train, y_ally_valid, y_advr_1_train, y_advr_1_valid, y_advr_2_train, y_advr_2_valid, ], locals(), 'Data size after train test split' ) scaler = StandardScaler() X_normalized_train = scaler.fit_transform(X_train) X_normalized_valid = scaler.transform(X_valid) log_shapes([X_normalized_train, X_normalized_valid], locals()) encoder = GeneratorFCN( X_normalized_train.shape[1], hidden_dim, encoding_dim, leaky, activation).to(device) ally = DiscriminatorFCN( encoding_dim, hidden_dim, ally_classes, leaky).to(device) advr_1 = DiscriminatorFCN( encoding_dim, hidden_dim, advr_1_classes, leaky).to(device) advr_2 = DiscriminatorFCN( encoding_dim, hidden_dim, advr_2_classes, leaky).to(device) if init_weight: sep() logging.info('applying weights_init ...') encoder.apply(weights_init) ally.apply(weights_init) advr_1.apply(weights_init) advr_2.apply(weights_init) sep('encoder') summary(encoder, input_size=(1, X_normalized_train.shape[1])) sep('ally') summary(ally, input_size=(1, encoding_dim)) sep('advr_1') summary(advr_1, input_size=(1, encoding_dim)) sep('advr_2') summary(advr_2, input_size=(1, encoding_dim)) optim = torch.optim.Adam criterionBCEWithLogits = nn.BCEWithLogitsLoss() criterionCrossEntropy = nn.CrossEntropyLoss() optimizer_encd = optim( encoder.parameters(), lr=lr_encd, weight_decay=lr_encd ) optimizer_ally = optim( ally.parameters(), lr=lr_ally, weight_decay=lr_ally ) optimizer_advr_1 = optim( advr_1.parameters(), lr=lr_advr_1, weight_decay=lr_advr_1 ) optimizer_advr_2 = optim( advr_2.parameters(), lr=lr_advr_2, weight_decay=lr_advr_2 ) dataset_train = utils.TensorDataset( torch.Tensor(X_normalized_train), torch.Tensor(y_ally_train.reshape(-1, 1)), torch.Tensor(y_advr_1_train.reshape(-1, 1)), torch.Tensor(y_advr_2_train.reshape(-1, 3)), ) dataloader_train = torch.utils.data.DataLoader( dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1 ) dataset_valid = utils.TensorDataset( torch.Tensor(X_normalized_valid), torch.Tensor(y_ally_valid.reshape(-1, 1)), torch.Tensor(y_advr_1_valid.reshape(-1, 1)), torch.Tensor(y_advr_2_valid.reshape(-1, 3)), ) dataloader_valid = torch.utils.data.DataLoader( dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1 ) epochs_train = [] epochs_valid = [] encd_loss_train = [] encd_loss_valid = [] ally_loss_train = [] ally_loss_valid = [] advr_1_loss_train = [] advr_1_loss_valid = [] advr_2_loss_train = [] advr_2_loss_valid = [] logging.info('{} \t {} \t {} \t {} \t {} \t {} \t {} \t {} \t {}'.format( 'Epoch', 'Encd Train', 'Encd Valid', 'Ally Train', 'Ally Valid', 'Advr 1 Train', 'Advr 1 Valid', 'Advr 2 Train', 'Advr 2 Valid', )) for epoch in range(n_epochs): encoder.train() 
ally.eval() advr_1.eval() advr_2.eval() for __ in range(g_reps): nsamples = 0 iloss = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_ally_train_torch = data[1].to(device) y_advr_1_train_torch = data[2].to(device) y_advr_2_train_torch = data[3].to(device) optimizer_encd.zero_grad() # Forward pass X_train_encoded = encoder(X_train_torch) y_ally_train_hat_torch = ally(X_train_encoded) y_advr_1_train_hat_torch = advr_1(X_train_encoded) y_advr_2_train_hat_torch = advr_2(X_train_encoded) # Compute Loss loss_ally = criterionBCEWithLogits( y_ally_train_hat_torch, y_ally_train_torch) loss_advr_1 = criterionBCEWithLogits( y_advr_1_train_hat_torch, y_advr_1_train_torch) loss_advr_2 = criterionCrossEntropy( y_advr_2_train_hat_torch, torch.argmax(y_advr_2_train_torch, 1)) loss_encd = loss_ally - loss_advr_1 - loss_advr_2 # Backward pass loss_encd.backward() optimizer_encd.step() nsamples += 1 iloss += loss_encd.item() epochs_train.append(epoch) encd_loss_train.append(iloss/nsamples) encoder.eval() ally.train() advr_1.train() advr_2.train() for __ in range(d_reps): nsamples = 0 iloss_ally = 0 iloss_advr_1 = 0 iloss_advr_2 = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_ally_train_torch = data[1].to(device) y_advr_1_train_torch = data[2].to(device) y_advr_2_train_torch = data[3].to(device) optimizer_ally.zero_grad() X_train_encoded = encoder(X_train_torch) y_ally_train_hat_torch = ally(X_train_encoded) loss_ally = criterionBCEWithLogits( y_ally_train_hat_torch, y_ally_train_torch) loss_ally.backward() optimizer_ally.step() optimizer_advr_1.zero_grad() X_train_encoded = encoder(X_train_torch) y_advr_1_train_hat_torch = advr_1(X_train_encoded) loss_advr_1 = criterionBCEWithLogits( y_advr_1_train_hat_torch, y_advr_1_train_torch) loss_advr_1.backward() optimizer_advr_1.step() optimizer_advr_2.zero_grad() X_train_encoded = encoder(X_train_torch) y_advr_2_train_hat_torch = advr_2(X_train_encoded) loss_advr_2 = criterionCrossEntropy( y_advr_2_train_hat_torch, torch.argmax(y_advr_2_train_torch, 1)) loss_advr_2.backward() optimizer_advr_2.step() nsamples += 1 iloss_ally += loss_ally.item() iloss_advr_1 += loss_advr_1.item() iloss_advr_2 += loss_advr_2.item() ally_loss_train.append(iloss_ally/nsamples) advr_1_loss_train.append(iloss_advr_1/nsamples) advr_2_loss_train.append(iloss_advr_2/nsamples) if epoch % int(n_epochs/10) != 0: continue encoder.eval() ally.eval() advr_1.eval() advr_2.eval() nsamples = 0 iloss = 0 iloss_ally = 0 iloss_advr_1 = 0 iloss_advr_2 = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = data[0].to(device) y_ally_valid_torch = data[1].to(device) y_advr_1_valid_torch = data[2].to(device) y_advr_2_valid_torch = data[3].to(device) X_valid_encoded = encoder(X_valid_torch) y_ally_valid_hat_torch = ally(X_valid_encoded) y_advr_1_valid_hat_torch = advr_1(X_valid_encoded) y_advr_2_valid_hat_torch = advr_2(X_valid_encoded) valid_loss_ally = criterionBCEWithLogits( y_ally_valid_hat_torch, y_ally_valid_torch) valid_loss_advr_1 = criterionBCEWithLogits( y_advr_1_valid_hat_torch, y_advr_1_valid_torch) valid_loss_advr_2 = criterionCrossEntropy( y_advr_2_valid_hat_torch, torch.argmax(y_advr_2_valid_torch, 1)) valid_loss_encd = valid_loss_ally - valid_loss_advr_1 - \ valid_loss_advr_2 nsamples += 1 iloss += valid_loss_encd.item() iloss_ally += valid_loss_ally.item() iloss_advr_1 += valid_loss_advr_1.item() iloss_advr_2 += valid_loss_advr_2.item() epochs_valid.append(epoch) encd_loss_valid.append(iloss/nsamples) 
ally_loss_valid.append(iloss_ally/nsamples) advr_1_loss_valid.append(iloss_advr_1/nsamples) advr_2_loss_valid.append(iloss_advr_2/nsamples) logging.info( '{} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f}'. format( epoch, encd_loss_train[-1], encd_loss_valid[-1], ally_loss_train[-1], ally_loss_valid[-1], advr_1_loss_train[-1], advr_1_loss_valid[-1], advr_2_loss_train[-1], advr_2_loss_valid[-1], )) config_summary = '{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_lrencd_{}_lrally_{}_tr_{:.4f}_val_{:.4f}'\ .format( marker, device, encoding_dim, hidden_dim, batch_size, n_epochs, lr_encd, lr_ally, encd_loss_train[-1], advr_1_loss_valid[-1], ) plt.plot(epochs_train, encd_loss_train, 'r') plt.plot(epochs_valid, encd_loss_valid, 'r--') plt.plot(epochs_train, ally_loss_train, 'b') plt.plot(epochs_valid, ally_loss_valid, 'b--') plt.plot(epochs_train, advr_1_loss_train, 'g') plt.plot(epochs_valid, advr_1_loss_valid, 'g--') plt.plot(epochs_train, advr_2_loss_train, 'y') plt.plot(epochs_valid, advr_2_loss_valid, 'y--') plt.legend([ 'encoder train', 'encoder valid', 'ally train', 'ally valid', 'advr 1 train', 'advr 1 valid', 'advr 2 train', 'advr 2 valid', ]) plt.title("{} on {} training".format(model, expt)) plot_location = 'plots/{}/{}_training_{}_{}.png'.format( expt, model, time_stamp, config_summary) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location) checkpoint_location = \ 'checkpoints/{}/{}_training_history_{}_{}.pkl'.format( expt, model, time_stamp, config_summary) logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump(( epochs_train, epochs_valid, encd_loss_train, encd_loss_valid, ally_loss_train, ally_loss_valid, advr_1_loss_train, advr_1_loss_valid, advr_2_loss_train, advr_2_loss_valid, ), open(checkpoint_location, 'wb')) model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format( expt, model, time_stamp, config_summary) logging.info('Saving: {}'.format(model_ckpt)) torch.save(encoder, model_ckpt)
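# GeneratorFCN (the encoder) and DiscriminatorFCN (the ally/adversary heads)
# are imported from the repo's model definitions. From the constructor
# arguments used above -- (input_dim, hidden_dim, output_dim, leaky
# [, activation]) -- they are small fully connected networks that use
# LeakyReLU when `leaky` is set and emit raw logits (matching the
# BCEWithLogitsLoss in the training loops). The sketch below is an
# assumption about their structure, not the repo's classes; the depth and
# the handling of `activation` in particular are guesses.
import torch.nn as nn


def _act(leaky):
    return nn.LeakyReLU(0.2, inplace=True) if leaky else nn.ReLU(inplace=True)


class GeneratorFCN(nn.Module):
    # Encoder: raw features -> encoding of size `encoding_dim`.
    def __init__(self, input_dim, hidden_dim, encoding_dim,
                 leaky=False, activation=None):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim), _act(leaky),
            nn.Linear(hidden_dim, hidden_dim), _act(leaky),
            nn.Linear(hidden_dim, encoding_dim),
        ]
        if activation == 'sigmoid':
            layers.append(nn.Sigmoid())
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class DiscriminatorFCN(nn.Module):
    # Ally/adversary head: encoding -> per-class logits (no final sigmoid).
    def __init__(self, encoding_dim, hidden_dim, n_classes, leaky=False):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(encoding_dim, hidden_dim), _act(leaky),
            nn.Linear(hidden_dim, hidden_dim), _act(leaky),
            nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, x):
        return self.net(x)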
def main(expt, model): pca_1 = pkl.load(open('checkpoints/mimic_centralized/ind_pca_training_history_01_20_2020_23_31_01.pkl', 'rb')) pca_2 = pkl.load(open('checkpoints/mimic_centralized/ind_pca_training_history_01_21_2020_00_19_41.pkl', 'rb')) auto_1 = pkl.load(open('checkpoints/mimic_centralized/ind_autoencoder_training_history_01_24_2020_13_50_25.pkl', 'rb')) dp_1 = pkl.load(open('checkpoints/mimic_centralized/ind_dp_training_history_01_24_2020_07_31_44.pkl', 'rb')) gan_1 = pkl.load(open('checkpoints/mimic_centralized/ind_gan_training_history_01_24_2020_13_16_16.pkl', 'rb')) gan_2 = pkl.load(open('checkpoints/mimic_centralized/ind_gan_training_history_01_25_2020_02_04_34.pkl', 'rb')) st_1_2 = pkl.load(open('checkpoints/mimic/ind_gan_dist_s1_2_training_history_04_07_2020_15_44_42.pkl', 'rb')) st_1_3 = pkl.load(open('checkpoints/mimic/ind_gan_dist_s1_nodes_3_training_history_04_16_2020_23_06_10.pkl', 'rb')) st_1_4 = pkl.load(open('checkpoints/mimic/ind_gan_dist_s1_nodes_4_training_history_04_17_2020_00_29_50.pkl', 'rb')) st_1_6 = pkl.load(open('checkpoints/mimic/ind_gan_dist_s1_nodes_6_training_history_04_17_2020_01_23_53.pkl', 'rb')) st_1_8 = pkl.load(open('checkpoints/mimic/ind_gan_dist_s1_nodes_8_training_history_04_17_2020_12_58_18.pkl', 'rb')) st_1_12 = pkl.load(open('checkpoints/mimic/ind_gan_dist_s1_nodes_12_training_history_04_17_2020_14_40_30.pkl', 'rb')) # gan_1 = pkl.load(open('checkpoints/mimic/ind_gan_training_history_01_27_2020_00_57_38.pkl', 'rb')) # gan_2 = gan_1 s = pkl.load(open('checkpoints/mimic_centralized/n_eigan_training_history_02_03_2020_00_59_27_B_device_cuda_dim_256_hidden_512_batch_16384_epochs_1001_ally_0_encd_0.0276_advr_0.5939.pkl','rb')) # print(pca_1.keys(), pca_2.keys(), auto_1.keys(), auto_2.keys(), dp_1.keys(), gan_1.keys()) # return plt.figure() fig = plt.figure(figsize=(15, 4)) ax1 = fig.add_subplot(131) ax3 = fig.add_subplot(132) ax2 = fig.add_subplot(133) t1, t3, t2 = '(a)', '(b)', '(c)' ax3.plot(pca_1['epoch']['valid'], gan_1['encoder']['ally_valid'], 'r') # ax3.plot(pca_1['epoch']['valid'], st_1_2['encoder']['ally_valid'], 'k*') # ax3.plot(pca_1['epoch']['valid'], st_1_3['encoder']['ally_valid'], 'k--') # ax3.plot(pca_1['epoch']['valid'], st_1_4['encoder']['ally_valid'], 'k+') # ax3.plot(pca_1['epoch']['valid'], st_1_6['encoder']['ally_valid'], 'ks') # ax3.plot(pca_1['epoch']['valid'], st_1_8['encoder']['ally_valid'], 'k.') # ax3.plot(pca_1['epoch']['valid'], st_1_12['encoder']['ally_valid'], 'k') ax3.plot(pca_1['epoch']['valid'], pca_1['pca']['ally_valid'], 'g') ax3.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['ally_valid'], 'b') ax3.plot(pca_1['epoch']['valid'], dp_1['dp']['ally_valid'], 'y') ax3.legend([ 'EIGAN', # '2 nodes', # '3 nodes', # '4 nodes', # '6 nodes', # '8 nodes', # '12 nodes', 'Autoencoder', 'PCA', 'DP', ],prop={'size':10}) ax3.set_title(t3, y=-0.32) ax3.set_xlabel('epochs') ax3.set_ylabel('ally log loss') ax3.grid() ax3.text(320,0.618, 'Lower is better', fontsize=14, color='r') ax3.set_ylim(bottom=0.58, top=0.8) ax3.set_xlim(left=0, right=1000) ax1.plot(pca_1['epoch']['valid'], gan_2['encoder']['advr_1_valid'], 'r', label='EIGAN') ax1.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_2_valid'], 'r--') # ax1.plot(pca_1['epoch']['valid'], st_1_2['encoder']['advr_1_valid'], 'k*', label='2 nodes') # ax1.plot(pca_1['epoch']['valid'], st_1_3['encoder']['advr_1_valid'], 'k--', label='3 nodes') # ax1.plot(pca_1['epoch']['valid'], st_1_4['encoder']['advr_1_valid'], 'k+', label='4 nodes') # ax1.plot(pca_1['epoch']['valid'], 
st_1_6['encoder']['advr_1_valid'], 'ks', label='6 nodes') # ax1.plot(pca_1['epoch']['valid'], st_1_8['encoder']['advr_1_valid'], 'k.', label='8 nodes') # ax1.plot(pca_1['epoch']['valid'], st_1_12['encoder']['advr_1_valid'], 'k', label='12 nodes') ax1.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_1_valid'], 'b', label='Autoencoder') ax1.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_2_valid'], 'b--') ax1.plot(pca_1['epoch']['valid'], pca_1['pca']['advr_1_valid'], 'g', label='PCA') ax1.plot(pca_1['epoch']['valid'], pca_2['pca']['advr_2_valid'], 'g--') ax1.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_1_valid'], 'y', label='DP') ax1.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_2_valid'], 'y--') # ax1.legend(prop={'size':10}) ax1.set_title(t1, y=-0.32) ax1.set_xlabel('epochs') ax1.set_ylabel('adversary log loss') ax1.grid() ax1.text(320,0.58, 'Higher is better', fontsize=14, color='r') ax1.set_ylim(bottom=0.53, top=0.81) ax1.set_xlim(left=0, right=1000) # ax2.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_2_valid'], 'r', label='EIGAN Adversary 2') # ax2.plot(pca_1['epoch']['valid'], st_1_2['encoder']['advr_2_valid'], 'k*', label='2 nodes') # ax2.plot(pca_1['epoch']['valid'], st_1_3['encoder']['advr_2_valid'], 'k--', label='3 nodes') # ax2.plot(pca_1['epoch']['valid'], st_1_4['encoder']['advr_2_valid'], 'k+', label='4 nodes') # ax2.plot(pca_1['epoch']['valid'], st_1_6['encoder']['advr_2_valid'], 'ks', label='6 nodes') # ax2.plot(pca_1['epoch']['valid'], st_1_8['encoder']['advr_2_valid'], 'k.', label='8 nodes') # ax2.plot(pca_1['epoch']['valid'], st_1_12['encoder']['advr_2_valid'], 'k', label='12 nodes') # ax2.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_2_valid'], 'b', label='autoencoder') # ax2.plot(pca_1['epoch']['valid'], pca_2['pca']['advr_2_valid'], 'g', label='PCA') # ax2.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_2_valid'], 'y', label='DP') ax2.plot(s[0], s[2], 'r', label='encoder loss') ax2.set_title('(c)', y=-0.32) ax2.plot(np.nan, 'b', label = 'adversary loss') ax2.legend(prop={'size':10}) ax2.set_xlabel('epochs') ax2.set_ylabel('log loss') ax2.grid() ax2.set_xlim(left=0,right=500) ax4 = ax2.twinx() ax4.plot(s[0], s[6], 'b') ax4.set_ylabel('adversary loss') fig.subplots_adjust(wspace=0.4) plot_location = 'plots/{}/{}_{}.png'.format(expt, 'all', model) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location, bbox_inches='tight', dpi=300)
def main(expt, model): pca_1 = pkl.load( open( 'checkpoints/mnist/ind_pca_training_history_01_30_2020_23_28_51.pkl', 'rb')) auto_1 = pkl.load( open( 'checkpoints/mnist/ind_autoencoder_training_history_01_30_2020_23_35_33.pkl', 'rb')) # dp_1 = pkl.load(open('checkpoints/mnist/ind_dp_training_history_02_01_2020_02_35_49.pkl', 'rb')) gan_1 = pkl.load( open( 'checkpoints/mnist/ind_gan_training_history_01_31_2020_16_05_44.pkl', 'rb')) u = pkl.load( open( 'checkpoints/mnist/eigan_training_history_02_05_2020_19_54_36_A_device_cuda_dim_1024_hidden_2048_batch_4096_epochs_501_lrencd_0.01_lrally_1e-05_lradvr_1e-05_tr_0.4023_val_1.7302.pkl', 'rb')) # print(pca_1.keys(), pca_2.keys(), auto_1.keys(), auto_2.keys(), dp_1.keys(), gan_1.keys()) # return plt.figure() fig = plt.figure(figsize=(15, 4)) ax1 = fig.add_subplot(131) ax3 = fig.add_subplot(132) ax2 = fig.add_subplot(133) t3, t1, t2 = '(b)', '(a)', '(c)' ax3.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_1_valid'], 'r') ax3.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_1_valid'], 'b') ax3.plot(pca_1['epoch']['valid'], pca_1['pca']['advr_1_valid'], 'g') # ax3.plot(pca_1['epoch']['valid'], dp_1['dp']['ally_valid'], 'y') ax3.legend([ 'EIGAN', 'Autoencoder', 'PCA', 'DP', ], prop={'size': 10}) ax3.set_title(t3, y=-0.32) ax3.set_xlabel('epochs') ax3.set_ylabel('ally log loss') ax3.grid() ax3.text(320, 1.68, 'Lower is better', fontsize=12, color='r') ax3.set_ylim(bottom=1.4) ax1.plot(pca_1['epoch']['valid'], gan_1['encoder']['ally_valid'], 'r--') ax1.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['ally_valid'], 'b--') ax1.plot(pca_1['epoch']['valid'], pca_1['pca']['ally_valid'], 'g--') # ax1.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_1_valid'], 'y--') # ax1.legend([ # 'EIGAN adversary', # 'Autoencoder adversary', # 'PCA adversary', # 'DP adversary', # ],prop={'size':10}) ax1.set_title(t1, y=-0.32) ax1.set_xlabel('epochs') ax1.set_ylabel('adversary log loss') ax1.grid() ax1.text(320, 0.57, 'Higher is better', fontsize=12, color='r') ax1.set_ylim(bottom=0.5) ax2.plot(u[0], u[2], 'r', label='encoder loss') ax2.plot(np.nan, 'b', label='adversary loss') ax4 = ax2.twinx() ax4.plot(u[0], u[6], 'b') ax2.set_title('(c)', y=-0.32) ax2.legend(prop={'size': 10}) ax2.set_xlabel('epochs') ax2.set_ylabel('encoder loss') ax2.grid() ax4.set_ylabel('adversary loss') fig.subplots_adjust(wspace=0.4) plot_location = 'plots/{}/{}_{}.png'.format(expt, 'all', model) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location, bbox_inches='tight', dpi=300)
def main( model, time_stamp, device, ally_classes, advr_1_classes, advr_2_classes, encoding_dim, hidden_dim, leaky, test_size, batch_size, n_epochs, shuffle, lr_ally, lr_advr_1, lr_advr_2, expt, pca_ckpt, autoencoder_ckpt, encoder_ckpt, ): device = torch_device(device=device) X_normalized_train, X_normalized_valid,\ y_ally_train, y_ally_valid, \ y_advr_1_train, y_advr_1_valid, \ y_advr_2_train, y_advr_2_valid = get_data(expt, test_size) pca = joblib.load(pca_ckpt) optim = torch.optim.Adam criterionBCEWithLogits = nn.BCEWithLogitsLoss() criterionCrossEntropy = nn.CrossEntropyLoss() h = { 'epoch': { 'train': [], 'valid': [], }, 'pca': { 'ally_train': [], 'ally_valid': [], 'advr_1_train': [], 'advr_1_valid': [], 'advr_2_train': [], 'advr_2_valid': [], }, } for _ in ['pca']: if _ == 'pca': dataset_train = utils.TensorDataset( torch.Tensor(pca.eval(X_normalized_train)), torch.Tensor(y_ally_train.reshape(-1, ally_classes)), torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)), torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)), ) dataset_valid = utils.TensorDataset( torch.Tensor(pca.eval(X_normalized_valid)), torch.Tensor(y_ally_valid.reshape(-1, ally_classes)), torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)), torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)), ) dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1) dataloader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1) ally = DiscriminatorFCN(encoding_dim, hidden_dim, ally_classes, leaky).to(device) advr_1 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_1_classes, leaky).to(device) advr_2 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_2_classes, leaky).to(device) ally.apply(weights_init) advr_1.apply(weights_init) advr_2.apply(weights_init) sep('{}:{}'.format(_, 'ally')) summary(ally, input_size=(1, encoding_dim)) sep('{}:{}'.format(_, 'advr 1')) summary(advr_1, input_size=(1, encoding_dim)) sep('{}:{}'.format(_, 'advr 2')) summary(advr_2, input_size=(1, encoding_dim)) optimizer_ally = optim(ally.parameters(), lr=lr_ally) optimizer_advr_1 = optim(advr_1.parameters(), lr=lr_advr_1) optimizer_advr_2 = optim(advr_2.parameters(), lr=lr_advr_2) # adversary 1 sep("adversary 1") logging.info('{} \t {} \t {}'.format( 'Epoch', 'Advr 1 Train', 'Advr 1 Valid', )) for epoch in range(n_epochs): advr_1.train() nsamples = 0 iloss_advr = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_advr_train_torch = data[2].to(device) optimizer_advr_1.zero_grad() y_advr_train_hat_torch = advr_1(X_train_torch) loss_advr = criterionCrossEntropy( y_advr_train_hat_torch, torch.argmax(y_advr_train_torch, 1)) loss_advr.backward() optimizer_advr_1.step() nsamples += 1 iloss_advr += loss_advr.item() h[_]['advr_1_train'].append(iloss_advr / nsamples) if epoch % int(n_epochs / 10) != 0: continue advr_1.eval() nsamples = 0 iloss_advr = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = data[0].to(device) y_advr_valid_torch = data[2].to(device) y_advr_valid_hat_torch = advr_1(X_valid_torch) valid_loss_advr = criterionCrossEntropy( y_advr_valid_hat_torch, torch.argmax(y_advr_valid_torch, 1)) tmp, predicted = torch.max(y_advr_valid_hat_torch, 1) tmp, actual = torch.max(y_advr_valid_torch, 1) nsamples += 1 iloss_advr += valid_loss_advr.item() total += actual.size(0) correct += (predicted == actual).sum().item() h[_]['advr_1_valid'].append(iloss_advr / nsamples) 
logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format( epoch, h[_]['advr_1_train'][-1], h[_]['advr_1_valid'][-1], correct / total)) # adversary sep("adversary 2") logging.info('{} \t {} \t {}'.format( 'Epoch', 'Advr 2 Train', 'Advr 2 Valid', )) for epoch in range(n_epochs): advr_2.train() nsamples = 0 iloss_advr = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_advr_train_torch = data[3].to(device) optimizer_advr_2.zero_grad() y_advr_train_hat_torch = advr_2(X_train_torch) loss_advr = criterionBCEWithLogits(y_advr_train_hat_torch, y_advr_train_torch) loss_advr.backward() optimizer_advr_2.step() nsamples += 1 iloss_advr += loss_advr.item() h[_]['advr_2_train'].append(iloss_advr / nsamples) if epoch % int(n_epochs / 10) != 0: continue advr_2.eval() nsamples = 0 iloss_advr = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = data[0].to(device) y_advr_valid_torch = data[3].to(device) y_advr_valid_hat_torch = advr_2(X_valid_torch) valid_loss_advr = criterionBCEWithLogits( y_advr_valid_hat_torch, y_advr_valid_torch) predicted = y_advr_valid_hat_torch > 0.5 nsamples += 1 iloss_advr += valid_loss_advr.item() total += y_advr_valid_torch.size(0) correct += (predicted == y_advr_valid_torch).sum().item() h[_]['advr_2_valid'].append(iloss_advr / nsamples) logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format( epoch, h[_]['advr_2_train'][-1], h[_]['advr_2_valid'][-1], correct / total)) #ally sep("ally") logging.info('{} \t {} \t {} \t {}'.format( 'Epoch', 'Ally Train', 'Ally Valid', 'Accuracy', )) for epoch in range(n_epochs): ally.train() nsamples = 0 iloss_ally = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_ally_train_torch = data[1].to(device) optimizer_ally.zero_grad() y_ally_train_hat_torch = ally(X_train_torch) loss_ally = criterionBCEWithLogits(y_ally_train_hat_torch, y_ally_train_torch) loss_ally.backward() optimizer_ally.step() nsamples += 1 iloss_ally += loss_ally.item() if epoch not in h['epoch']['train']: h['epoch']['train'].append(epoch) h[_]['ally_train'].append(iloss_ally / nsamples) if epoch % int(n_epochs / 10) != 0: continue ally.eval() nsamples = 0 iloss_ally = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = data[0].to(device) y_ally_valid_torch = data[1].to(device) y_ally_valid_hat_torch = ally(X_valid_torch) valid_loss_ally = criterionBCEWithLogits( y_ally_valid_hat_torch, y_ally_valid_torch) predicted = y_ally_valid_hat_torch > 0.5 nsamples += 1 iloss_ally += valid_loss_ally.item() total += y_ally_valid_torch.size(0) correct += (predicted == y_ally_valid_torch).sum().item() if epoch not in h['epoch']['valid']: h['epoch']['valid'].append(epoch) h[_]['ally_valid'].append(iloss_ally / nsamples) logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format( epoch, h[_]['ally_train'][-1], h[_]['ally_valid'][-1], correct / total)) checkpoint_location = \ 'checkpoints/{}/{}_training_history_{}.pkl'.format( expt, model, time_stamp) sep() logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump(h, open(checkpoint_location, 'wb'))
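# Several of the MIMIC training scripts above load their pre-built
# feature/label arrays through load_processed_data(expt, filename), which,
# judging from the call sites, simply unpickles whatever the preprocessing
# step wrote out. A minimal sketch under that assumption; the
# 'processed_data/<expt>/' layout is a guess and not taken from the repo.
import os
import pickle as pkl


def load_processed_data(expt, filename):
    path = os.path.join('processed_data', expt, filename)
    with open(path, 'rb') as f:
        return pkl.load(f)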
def main(expt, model): gan_d_128 = pkl.load( open( 'checkpoints/titanic/eigan_training_history_01_25_2020_22_26_59_F_device_cuda_dim_128_hidden_256_batch_1024_epochs_1001_lrencd_1e-05_lrally_1e-05_tr_-0.1927_val_0.6559.pkl', 'rb')) gan_d_256 = pkl.load( open( 'checkpoints/titanic/eigan_training_history_01_25_2020_22_29_45_F_device_cuda_dim_256_hidden_512_batch_1024_epochs_1001_lrencd_1e-05_lrally_1e-05_tr_-0.1852_val_0.6548.pkl', 'rb')) gan_d_512 = pkl.load( open( 'checkpoints/titanic/eigan_training_history_01_25_2020_22_32_52_F_device_cuda_dim_512_hidden_1024_batch_1024_epochs_1001_lrencd_1e-05_lrally_1e-05_tr_-0.1820_val_0.6553.pkl', 'rb')) gan_d_1024 = pkl.load( open( 'checkpoints/titanic/eigan_training_history_01_25_2020_22_36_17_F_device_cuda_dim_1024_hidden_2048_batch_1024_epochs_1001_lrencd_1e-05_lrally_1e-05_tr_-0.1834_val_0.6484.pkl', 'rb')) gan_d_2048 = pkl.load( open( 'checkpoints/titanic/eigan_training_history_01_25_2020_22_40_32_F_device_cuda_dim_2048_hidden_4086_batch_1024_epochs_1001_lrencd_1e-05_lrally_1e-05_tr_-0.1826_val_0.6424.pkl', 'rb')) # print(pca_1.keys(), pca_2.keys(), auto_1.keys(), auto_2.keys(), dp_1.keys(), gan_1.keys()) # return plt.figure() fig = plt.figure(figsize=(15, 5)) ax3 = fig.add_subplot(131) ax1 = fig.add_subplot(132) ax2 = fig.add_subplot(133) t3, t1, t2 = '(a)', '(b)', '(c)' ax3.plot(pca_1['epoch']['valid'], gan_1['encoder']['ally_valid'], 'r') ax3.plot(pca_1['epoch']['valid'], pca_1['pca']['ally_valid'], 'g') ax3.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['ally_valid'], 'b') ax3.plot(pca_1['epoch']['valid'], dp_1['dp']['ally_valid'], 'y') ax3.legend([ 'gan ally', 'autoencoder ally', 'pca ally', 'dp ally', ]) ax3.set_title(t3, y=-0.2) ax3.set_xlabel('iterations (scale adjusted)') ax3.set_ylabel('loss') ax1.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_1_valid'], 'r--') ax1.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_1_valid'], 'b--') ax1.plot(pca_1['epoch']['valid'], pca_1['pca']['advr_1_valid'], 'g--') ax1.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_1_valid'], 'y--') ax1.legend([ 'gan adversary 1', 'autoencoder adversary 1', 'pca adversary 1', 'dp adversary 1', ]) ax1.set_title(t1, y=-0.2) ax1.set_xlabel('iterations (scale adjusted)') ax1.set_ylabel('loss') ax2.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_2_valid'], 'r--') ax2.plot(pca_1['epoch']['valid'], auto_2['autoencoder']['advr_2_valid'], 'b--') ax2.plot(pca_1['epoch']['valid'], pca_2['pca']['advr_2_valid'], 'g--') ax2.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_2_valid'], 'y--') ax2.legend([ 'gan adversary 2', 'autoencoder adversary 2', 'pca adversary 2', 'dp adversary 2', ]) ax2.set_title(t2, y=-0.2) ax2.set_xlabel('iterations (scale adjusted)') ax2.set_ylabel('loss') plot_location = 'plots/{}/{}_{}_b4096.png'.format(expt, 'all', model) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location, bbox_inches='tight')
model = 'ind_gan'
marker = 'H'

pr_time, fl_time = time_stp()
logger(expt, model, fl_time, marker)
log_time('Start', pr_time)

args = comparison_argparse()
# --n-ally/--n-advr and --lr-ally/--lr-advr are parsed with nargs='+', so the
# values arrive as lists; the ally value is taken first and the two adversary
# values are read positionally (this indexing is the assumed convention).
main(model=model,
     time_stamp=fl_time,
     device=args['device'],
     ally_classes=args['n_ally'][0],
     advr_1_classes=args['n_advr'][0],
     advr_2_classes=args['n_advr'][1],
     encoding_dim=args['dim'],
     hidden_dim=args['hidden_dim'],
     leaky=args['leaky'],
     test_size=args['test_size'],
     batch_size=args['batch_size'],
     n_epochs=args['n_epochs'],
     shuffle=args['shuffle'] == 1,
     lr_ally=args['lr_ally'][0],
     lr_advr_1=args['lr_advr'][0],
     lr_advr_2=args['lr_advr'][1],
     expt=args['expt'],
     pca_ckpt=args['pca_ckpt'],
     autoencoder_ckpt=args['autoencoder_ckpt'],
     encoder_ckpt=args['encoder_ckpt'])

log_time('End', time_stp()[0])
sep()
def main( model, time_stamp, device, ngpu, encoding_dim, hidden_dim, leaky, activation, test_size, batch_size, n_epochs, shuffle, init_weight, lr_encd, lr_ally, lr_advr, alpha, expt, num_allies, marker ): device = torch_device(device=device) X, targets = load_processed_data( expt, 'processed_data_X_targets.pkl') log_shapes( [X] + [targets[i] for i in targets], locals(), 'Dataset loaded' ) targets = {i: elem.reshape(-1, 1) for i, elem in targets.items()} X_train, X_valid, \ y_hef_train, y_hef_valid, \ y_exf_train, y_exf_valid, \ y_gdr_train, y_gdr_valid, \ y_lan_train, y_lan_valid, \ y_mar_train, y_mar_valid, \ y_rel_train, y_rel_valid, \ y_ins_train, y_ins_valid, \ y_dis_train, y_dis_valid, \ y_adl_train, y_adl_valid, \ y_adt_train, y_adt_valid, \ y_etn_train, y_etn_valid = train_test_split( X, targets['hospital_expire_flag'], targets['expire_flag'], targets['gender'], targets['language'], targets['marital_status'], targets['religion'], targets['insurance'], targets['discharge_location'], targets['admission_location'], targets['admission_type'], targets['ethnicity'], test_size=test_size, stratify=pd.DataFrame(np.concatenate( ( targets['admission_type'], ), axis=1) ) ) log_shapes( [ X_train, X_valid, y_hef_train, y_hef_valid, y_exf_train, y_exf_valid, y_gdr_train, y_gdr_valid, y_lan_train, y_lan_valid, y_mar_train, y_mar_valid, y_rel_train, y_rel_valid, y_ins_train, y_ins_valid, y_dis_train, y_dis_valid, y_adl_train, y_adl_valid, y_adt_train, y_adt_valid, y_etn_train, y_etn_valid ], locals(), 'Data size after train test split' ) y_ally_trains = [ y_hef_train, y_exf_train, y_gdr_train, y_lan_train, y_mar_train, y_rel_train, y_ins_train, y_dis_train, y_adl_train, y_etn_train, ] y_ally_valids = [ y_hef_valid, y_exf_valid, y_gdr_valid, y_lan_valid, y_mar_valid, y_rel_valid, y_ins_valid, y_dis_valid, y_adl_valid, y_etn_valid, ] y_advr_train = y_adt_train y_advr_valid = y_adt_valid scaler = StandardScaler() X_normalized_train = scaler.fit_transform(X_train) X_normalized_valid = scaler.transform(X_valid) log_shapes([X_normalized_train, X_normalized_valid], locals()) for i in [num_allies-1]: sep('NUMBER OF ALLIES: {}'.format(i+1)) encoder = GeneratorFCN( X_normalized_train.shape[1], hidden_dim, encoding_dim, leaky, activation).to(device) ally = {} for j in range(i+1): ally[j] = DiscriminatorFCN( encoding_dim, hidden_dim, 1, leaky).to(device) advr = DiscriminatorFCN( encoding_dim, hidden_dim, 1, leaky).to(device) if init_weight: sep() logging.info('applying weights_init ...') encoder.apply(weights_init) for j in range(i+1): ally[j].apply(weights_init) advr.apply(weights_init) sep('encoder') summary(encoder, input_size=(1, X_normalized_train.shape[1])) for j in range(i+1): sep('ally:{}'.format(j)) summary(ally[j], input_size=(1, encoding_dim)) sep('advr') summary(advr, input_size=(1, encoding_dim)) optim = torch.optim.Adam criterionBCEWithLogits = nn.BCEWithLogitsLoss() optimizer_encd = optim(encoder.parameters(), lr=lr_encd) optimizer_ally = {} for j in range(i+1): optimizer_ally[j] = optim(ally[j].parameters(), lr=lr_ally) optimizer_advr = optim(advr.parameters(), lr=lr_advr) dataset_train = utils.TensorDataset( torch.Tensor(X_normalized_train), torch.Tensor(y_advr_train), ) for y_ally_train in y_ally_trains: dataset_train.tensors = (*dataset_train.tensors, torch.Tensor(y_ally_train)) dataloader_train = torch.utils.data.DataLoader( dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1 ) dataset_valid = utils.TensorDataset( torch.Tensor(X_normalized_valid), torch.Tensor(y_advr_valid), 
) for y_ally_valid in y_ally_valids: dataset_valid.tensors = (*dataset_valid.tensors, torch.Tensor(y_ally_valid)) dataloader_valid = torch.utils.data.DataLoader( dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1 ) epochs_train = [] epochs_valid = [] encd_loss_train = [] encd_loss_valid = [] ally_loss_train = {} ally_loss_valid = {} for j in range(i+1): ally_loss_train[j] = [] ally_loss_valid[j] = [] advr_loss_train = [] advr_loss_valid = [] log_list = ['epoch', 'encd_train', 'encd_valid', 'advr_train', 'advr_valid'] + \ ['ally_{}_train \t ally_{}_valid'.format(str(j), str(j)) for j in range(i+1)] logging.info(' \t '.join(log_list)) for epoch in range(n_epochs): encoder.train() for j in range(i+1): ally[i].eval() advr.eval() nsamples = 0 iloss = 0 for data in dataloader_train: X_train_torch = data[0].to(device) y_advr_train_torch = data[1].to(device) y_ally_train_torch = {} for j in range(i+1): y_ally_train_torch[j] = data[j+2].to(device) optimizer_encd.zero_grad() # Forward pass X_train_encoded = encoder(X_train_torch) y_advr_train_hat_torch = advr(X_train_encoded) y_ally_train_hat_torch = {} for j in range(i+1): y_ally_train_hat_torch[j] = ally[j](X_train_encoded) # Compute Loss loss_ally = {} for j in range(i+1): loss_ally[j] = criterionBCEWithLogits( y_ally_train_hat_torch[j], y_ally_train_torch[j]) loss_advr = criterionBCEWithLogits( y_advr_train_hat_torch, y_advr_train_torch) loss_encd = alpha/num_allies * sum([loss_ally[_].item() for _ in loss_ally]) - (1-alpha) * loss_advr # Backward pass loss_encd.backward() optimizer_encd.step() nsamples += 1 iloss += loss_encd.item() epochs_train.append(epoch) encd_loss_train.append(iloss/nsamples) encoder.eval() for j in range(i+1): ally[j].train() advr.train() nsamples = 0 iloss_ally = {} for j in range(i+1): iloss_ally[j] = 0 iloss_advr = 0 for data in dataloader_train: X_train_torch = data[0].to(device) y_advr_train_torch = data[1].to(device) y_ally_train_torch = {} for j in range(i+1): y_ally_train_torch[j] = data[j+2].to(device) y_ally_train_hat_torch = {} loss_ally = {} for j in range(i+1): optimizer_ally[j].zero_grad() X_train_encoded = encoder(X_train_torch) y_ally_train_hat_torch[j] = ally[j](X_train_encoded) loss_ally[j] = criterionBCEWithLogits( y_ally_train_hat_torch[j], y_ally_train_torch[j]) loss_ally[j].backward() optimizer_ally[j].step() optimizer_advr.zero_grad() X_train_encoded = encoder(X_train_torch) y_advr_train_hat_torch = advr(X_train_encoded) loss_advr = criterionBCEWithLogits( y_advr_train_hat_torch, y_advr_train_torch) loss_advr.backward() optimizer_advr.step() nsamples += 1 for j in range(i+1): iloss_ally[j] += loss_ally[j].item() iloss_advr += loss_advr.item() for j in range(i+1): ally_loss_train[j].append(iloss_ally[j]/nsamples) advr_loss_train.append(iloss_advr/nsamples) if epoch % int(n_epochs/10) != 0: continue encoder.eval() for j in range(i+1): ally[j].eval() advr.eval() nsamples = 0 iloss = 0 iloss_ally = {} for j in range(i+1): iloss_ally[j] = 0 iloss_advr = 0 for data in dataloader_valid: X_valid_torch = data[0].to(device) y_advr_valid_torch = data[1].to(device) y_ally_valid_torch = {} for j in range(i+1): y_ally_valid_torch[j] = data[j+2].to(device) X_valid_encoded = encoder(X_valid_torch) y_ally_valid_hat_torch = {} for j in range(i+1): y_ally_valid_hat_torch[j] = ally[j](X_valid_encoded) y_advr_valid_hat_torch = advr(X_valid_encoded) valid_loss_ally = {} for j in range(i+1): valid_loss_ally[j] = criterionBCEWithLogits( y_ally_valid_hat_torch[j], y_ally_valid_torch[j]) valid_loss_advr = 
criterionBCEWithLogits( y_advr_valid_hat_torch, y_advr_valid_torch) valid_loss_encd = alpha/num_allies*sum( [valid_loss_ally[_].item() for _ in valid_loss_ally] ) - (1-alpha)* valid_loss_advr nsamples += 1 iloss += valid_loss_encd.item() for j in range(i+1): iloss_ally[j] += valid_loss_ally[j].item() iloss_advr += valid_loss_advr.item() epochs_valid.append(epoch) encd_loss_valid.append(iloss/nsamples) for j in range(i+1): ally_loss_valid[j].append(iloss_ally[j]/nsamples) advr_loss_valid.append(iloss_advr/nsamples) log_line = [str(epoch), '{:.8f}'.format(encd_loss_train[-1]), '{:.8f}'.format(encd_loss_valid[-1]), '{:.8f}'.format(advr_loss_train[-1]), '{:.8f}'.format(advr_loss_valid[-1]), ] + \ [ '{:.8f} \t {:.8f}'.format( ally_loss_train[_][-1], ally_loss_valid[_][-1] ) for _ in ally_loss_train] logging.info(' \t '.join(log_line)) config_summary = '{}_n_{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_ally_{}_encd_{:.4f}_advr_{:.4f}'\ .format( marker, num_allies, device, encoding_dim, hidden_dim, batch_size, n_epochs, i, encd_loss_train[-1], advr_loss_valid[-1], ) plt.figure() plt.plot(epochs_train, encd_loss_train, 'r', label='encd train') plt.plot(epochs_valid, encd_loss_valid, 'r--', label='encd valid') # sum_loss = [0] * len(ally_loss_train[0]) # for j in range(i+1): # for k in ally_loss_valid: # sum_loss[j] += ally_loss_train[j] # sum_loss = [sum_loss[j]/len(ally_loss_valid) for j in sum_loss] # plt.plot( # epochs_train, # sum([ally_loss_train[j] for j in range(i+1)])/len(ally_loss_train), # 'b', label='ally_sum_train') # sum_loss = [0] * len(ally_loss_valid) # for j in range(i+1): # sum_loss[j] += ally_loss_valid[j] # sum_loss = [sum_loss[j]/len(ally_loss_valid) for j in sum_loss] # plt.plot(epochs_valid, sum_loss, 'b', label='ally_sum_valid') plt.plot(epochs_train, advr_loss_train, 'g', label='advr_train') plt.plot(epochs_valid, advr_loss_valid, 'g--', label='advr_valid') plt.legend() plt.title("{} on {} training".format(model, expt)) plot_location = 'plots/{}/{}_training_{}_{}.png'.format( expt, model, time_stamp, config_summary) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location) checkpoint_location = \ 'checkpoints/{}/{}_training_history_{}_{}.pkl'.format( expt, model, time_stamp, config_summary) logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump(( epochs_train, epochs_valid, encd_loss_train, encd_loss_valid, ally_loss_train, ally_loss_valid, advr_loss_train, advr_loss_valid, ), open(checkpoint_location, 'wb')) model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format( expt, model, time_stamp, config_summary) logging.info('Saving: {}'.format(model_ckpt)) torch.save(encoder, model_ckpt)
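# The encoder step above folds the ally and adversary objectives into one loss.
# Below is a minimal, self-contained sketch of that objective under the same alpha
# weighting (illustrative only, not the training code itself). The ally losses are
# kept as tensors (no .item()) so that gradients can reach the encoder from both
# sides of the trade-off.
def encoder_objective_sketch(ally_losses, advr_loss, alpha, num_allies):
    # ally_losses: iterable of per-ally BCE-with-logits loss tensors
    # advr_loss:   adversary BCE-with-logits loss tensor
    ally_term = sum(ally_losses)
    return (alpha / num_allies) * ally_term - (1 - alpha) * advr_loss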
def main( model, time_stamp, device, ally_classes, advr_1_classes, advr_2_classes, encoding_dim, test_size, batch_size, n_epochs, shuffle, lr, expt, ): device = torch_device(device=device) # refer to PrivacyGAN_Titanic for data preparation X, y_ally, y_advr_1, y_advr_2 = load_processed_data( expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl') log_shapes( [X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded' ) X_train, X_valid = train_test_split( X, test_size=test_size, stratify=pd.DataFrame(np.concatenate( ( y_ally.reshape(-1, ally_classes), y_advr_1.reshape(-1, advr_1_classes), y_advr_2.reshape(-1, advr_2_classes), ), axis=1) ) ) log_shapes( [ X_train, X_valid, ], locals(), 'Data size after train test split' ) scaler = StandardScaler() X_train_normalized = scaler.fit_transform(X_train) X_valid_normalized = scaler.transform(X_valid) log_shapes([X_train_normalized, X_valid_normalized], locals()) dataset_train = utils.TensorDataset(torch.Tensor(X_train_normalized)) dataloader_train = torch.utils.data.DataLoader( dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=2) dataset_valid = utils.TensorDataset(torch.Tensor(X_valid_normalized)) dataloader_valid = torch.utils.data.DataLoader( dataset_valid, batch_size=batch_size, shuffle=False, num_workers=2) auto_encoder = AutoEncoderBasic( input_size=X_train_normalized.shape[1], encoding_dim=encoding_dim ).to(device) criterion = torch.nn.MSELoss() adam_optim = torch.optim.Adam optimizer = adam_optim(auto_encoder.parameters(), lr=lr) summary(auto_encoder, input_size=(1, X_train_normalized.shape[1])) h_epoch = [] h_valid = [] h_train = [] auto_encoder.train() sep() logging.info("epoch \t Aencoder_train \t Aencoder_valid") for epoch in range(n_epochs): nsamples = 0 iloss = 0 for data in dataloader_train: optimizer.zero_grad() X_torch = data[0].to(device) X_torch_hat = auto_encoder(X_torch) loss = criterion(X_torch_hat, X_torch) loss.backward() optimizer.step() nsamples += 1 iloss += loss.item() if epoch % int(n_epochs/10) != 0: continue h_epoch.append(epoch) h_train.append(iloss/nsamples) nsamples = 0 iloss = 0 for data in dataloader_valid: X_torch = data[0].to(device) X_torch_hat = auto_encoder(X_torch) loss = criterion(X_torch_hat, X_torch) nsamples += 1 iloss += loss.item() h_valid.append(iloss/nsamples) logging.info('{} \t {:.8f} \t {:.8f}'.format( h_epoch[-1], h_train[-1], h_valid[-1], )) config_summary = 'device_{}_dim_{}_batch_{}_epochs_{}_lr_{}_tr_{:.4f}_val_{:.4f}'\ .format( device, encoding_dim, batch_size, n_epochs, lr, h_train[-1], h_valid[-1], ) plt.plot(h_epoch, h_train, 'r--') plt.plot(h_epoch, h_valid, 'b--') plt.legend(['train_loss', 'valid_loss']) plt.title("autoencoder training {}".format(config_summary)) plot_location = 'plots/{}/{}_training_{}_{}.png'.format( expt, model, time_stamp, config_summary) sep() logging.info('Saving: {}'.format(plot_location)) plt.savefig(plot_location) checkpoint_location = \ 'checkpoints/{}/{}_training_history_{}_{}.pkl'.format( expt, model, time_stamp, config_summary) logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump((h_epoch, h_train, h_valid), open(checkpoint_location, 'wb')) model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format( expt, model, time_stamp, config_summary) logging.info('Saving: {}'.format(model_ckpt)) torch.save(auto_encoder, model_ckpt)
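# AutoEncoderBasic is imported from elsewhere in the repository. The class below is a
# hypothetical stand-in that only illustrates the interface assumed by the training
# loop above (input_size, encoding_dim, MSE reconstruction); it is not the actual
# implementation.
import torch.nn as nn

class AutoEncoderSketch(nn.Module):
    def __init__(self, input_size, encoding_dim):
        super().__init__()
        # single bottleneck of width encoding_dim, trained against an MSE criterion
        self.encoder = nn.Sequential(nn.Linear(input_size, encoding_dim), nn.ReLU())
        self.decoder = nn.Linear(encoding_dim, input_size)

    def forward(self, x):
        return self.decoder(self.encoder(x))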
def main( model, time_stamp, device, ally_classes, advr_1_classes, advr_2_classes, encoding_dim, hidden_dim, leaky, test_size, batch_size, n_epochs, shuffle, lr, expt, pca_ckpt, autoencoder_ckpt, encoder_ckpt, ): device = torch_device(device=device) X, targets = load_processed_data(expt, 'processed_data_X_targets.pkl') log_shapes([X] + [targets[i] for i in targets], locals(), 'Dataset loaded') h = {} for name, target in targets.items(): sep(name) target = target.reshape(-1, 1) X_train, X_valid, \ y_train, y_valid = train_test_split( X, target, test_size=test_size, stratify=target ) log_shapes([ X_train, X_valid, y_train, y_valid, ], locals(), 'Data size after train test split') scaler = StandardScaler() X_normalized_train = scaler.fit_transform(X_train) X_normalized_valid = scaler.transform(X_valid) log_shapes([X_normalized_train, X_normalized_valid], locals()) optim = torch.optim.Adam criterionBCEWithLogits = nn.BCEWithLogitsLoss() h[name] = { 'epoch_train': [], 'epoch_valid': [], 'y_train': [], 'y_valid': [], } dataset_train = utils.TensorDataset( torch.Tensor(X_normalized_train), torch.Tensor(y_train.reshape(-1, 1))) dataset_valid = utils.TensorDataset( torch.Tensor(X_normalized_valid), torch.Tensor(y_valid.reshape(-1, 1)), ) dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1) dataloader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1) clf = DiscriminatorFCN(encoding_dim, hidden_dim, 1, leaky).to(device) clf.apply(weights_init) sep('{}:{}'.format(name, 'summary')) summary(clf, input_size=(1, encoding_dim)) optimizer = optim(clf.parameters(), lr=lr) # adversary 1 sep("TRAINING") logging.info('{} \t {} \t {}'.format( 'Epoch', 'Train', 'Valid', )) for epoch in range(n_epochs): clf.train() nsamples = 0 iloss_train = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = data[0].to(device) y_train_torch = data[1].to(device) optimizer.zero_grad() y_train_hat_torch = clf(X_train_torch) loss_train = criterionBCEWithLogits(y_train_hat_torch, y_train_torch) loss_train.backward() optimizer.step() nsamples += 1 iloss_train += loss_train.item() h[name]['y_train'].append(iloss_train / nsamples) h[name]['epoch_train'].append(epoch) if epoch % int(n_epochs / 10) != 0: continue clf.eval() nsamples = 0 iloss_valid = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = data[0].to(device) y_valid_torch = data[1].to(device) y_valid_hat_torch = clf(X_valid_torch) valid_loss = criterionBCEWithLogits( y_valid_hat_torch, y_valid_torch, ) predicted = y_valid_hat_torch > 0.5 nsamples += 1 iloss_valid += valid_loss.item() total += y_valid_torch.size(0) correct += (predicted == y_valid_torch).sum().item() h[name]['y_valid'].append(iloss_valid / nsamples) h[name]['epoch_valid'].append(epoch) logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format( epoch, h[name]['y_train'][-1], h[name]['y_valid'][-1], correct / total)) checkpoint_location = \ 'checkpoints/{}/{}_training_history_{}.pkl'.format( expt, model, time_stamp) sep() logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump(h, open(checkpoint_location, 'wb'))
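# With nn.BCEWithLogitsLoss the classifier above emits raw logits. One common way to
# turn logits into hard predictions for the accuracy figure is to pass them through a
# sigmoid before the 0.5 threshold (equivalently, threshold the logit at 0). Minimal
# sketch for reference:
import torch

def binary_accuracy_from_logits(logits, targets):
    preds = (torch.sigmoid(logits) > 0.5).float()  # same decision rule as logits > 0
    return (preds == targets).float().mean().item()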
def main(expt, model):
    pca_1 = pkl.load(
        open(
            'checkpoints/titanic/ind_pca_training_history_01_15_2020_23_25_44.pkl',
            'rb'))
    pca_2 = pkl.load(
        open(
            'checkpoints/titanic/ind_pca_training_history_01_15_2020_23_45_00.pkl',
            'rb'))
    auto_1 = pkl.load(
        open(
            'checkpoints/titanic/ind_autoencoder_training_history_01_16_2020_03_53_53.pkl',
            'rb'))
    auto_2 = pkl.load(
        open(
            'checkpoints/titanic/ind_autoencoder_training_history_01_16_2020_04_30_49.pkl',
            'rb'))
    dp_1 = pkl.load(
        open(
            'checkpoints/titanic/ind_dp_training_history_01_30_2020_14_11_06.pkl',
            'rb'))
    gan_1 = pkl.load(
        open(
            'checkpoints/titanic/ind_gan_training_history_01_16_2020_21_56_04.pkl',
            'rb'))
    s = pkl.load(
        open(
            'checkpoints/titanic/eigan_training_history_01_15_2020_22_43_42_E_device_cuda_dim_1400_hidden_2800_batch_1024_epochs_1001_lrencd_1e-05_lrally_1e-05_tr_-0.1852_val_0.6462.pkl',
            'rb'))
    # checkpoints/titanic/ind_gan_training_history_01_16_2020_21_56_04.pkl
    # print(pca_1.keys(), pca_2.keys(), auto_1.keys(), auto_2.keys(), dp_1.keys(), gan_1.keys())
    # return

    fig = plt.figure(figsize=(15, 3))
    ax3 = fig.add_subplot(131)
    ax1 = fig.add_subplot(132)
    ax2 = fig.add_subplot(133)
    t3, t1, t2 = '(a)', '(b)', '(c)'

    ax3.plot(pca_1['epoch']['valid'], gan_1['encoder']['ally_valid'], 'r')
    ax3.plot(pca_1['epoch']['valid'], pca_1['pca']['ally_valid'], 'g')
    ax3.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['ally_valid'], 'b')
    ax3.plot(pca_1['epoch']['valid'], dp_1['dp']['ally_valid'], 'y')
    # legend entries ordered to match the plotting order (and the colour
    # convention of the middle panel): EIGAN red, PCA green, Autoencoder blue, DP yellow
    ax3.legend([
        'EIGAN ally',
        'PCA ally',
        'Autoencoder ally',
        'DP ally',
    ], prop={'size': 10})
    ax3.set_title(t3, y=-0.32)
    ax3.set_xlabel('epochs')
    ax3.set_ylabel('log loss')
    ax3.grid()
    ax3.set_xlim(left=0, right=1000)
    ax3.text(320, 0.67, 'Lower is better', fontsize=12, color='r')

    ax1.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_1_valid'], 'r',
             label='EIGAN adversary')
    ax1.plot(pca_1['epoch']['valid'], auto_1['autoencoder']['advr_1_valid'], 'b',
             label='Autoencoder adversary')
    ax1.plot(pca_1['epoch']['valid'], pca_1['pca']['advr_1_valid'], 'g',
             label='PCA adversary')
    ax1.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_1_valid'], 'y',
             label='DP adversary')
    ax1.plot(pca_1['epoch']['valid'], gan_1['encoder']['advr_2_valid'], 'r--')
    ax1.plot(pca_1['epoch']['valid'], auto_2['autoencoder']['advr_2_valid'], 'b--')
    ax1.plot(pca_1['epoch']['valid'], pca_2['pca']['advr_2_valid'], 'g--')
    ax1.plot(pca_1['epoch']['valid'], dp_1['dp']['advr_2_valid'], 'y--')
    ax1.legend(prop={'size': 10})
    ax1.set_title(t1, y=-0.32)
    ax1.set_xlabel('epochs')
    ax1.set_ylabel('log loss')
    ax1.grid()
    ax1.set_xlim(left=0, right=1000)
    ax1.text(320, 0.66, 'Higher is better', fontsize=12, color='r')

    ax2.plot(s[0], s[2], 'r', label='encoder loss')
    ax2.set_title(t2, y=-0.32)
    ax2.plot(np.nan, 'b', label='adversary loss')
    ax2.legend(prop={'size': 10})
    ax2.set_xlabel('epochs')
    ax2.set_ylabel('encoder loss')
    ax2.grid()
    ax2.set_xlim(left=0, right=1000)
    ax4 = ax2.twinx()
    ax4.plot(s[0], s[6], 'b')
    ax4.set_ylabel('adversary loss')

    fig.subplots_adjust(wspace=0.3)
    plot_location = 'plots/{}/{}_{}_.png'.format(expt, 'all', model)
    sep()
    logging.info('Saving: {}'.format(plot_location))
    plt.savefig(plot_location, bbox_inches='tight', dpi=300)
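# The `s` history unpickled in the plotting routine above is assumed to follow the
# tuple layout written by the EIGAN training script earlier in this file:
#   (epochs_train, epochs_valid, encd_loss_train, encd_loss_valid,
#    ally_loss_train, ally_loss_valid, advr_loss_train, advr_loss_valid)
# so s[0], s[2] and s[6] correspond to the epochs, the encoder training loss and the
# adversary training loss drawn on the twin-axis panel.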
def main( model, time_stamp, device, ally_classes, advr_1_classes, advr_2_classes, encoding_dim, hidden_dim, leaky, test_size, batch_size, n_epochs, shuffle, lr_ally, lr_advr_1, lr_advr_2, expt, pca_ckpt, autoencoder_ckpt, encoder_ckpt, ): device = torch_device(device=device) X, y_ally, y_advr_1, y_advr_2 = load_processed_data( expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl') log_shapes([X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded') X_train, X_valid, \ y_ally_train, y_ally_valid, \ y_advr_1_train, y_advr_1_valid, \ y_advr_2_train, y_advr_2_valid = train_test_split( X, y_ally, y_advr_1, y_advr_2, test_size=test_size, stratify=pd.DataFrame(np.concatenate( ( y_ally.reshape(-1, ally_classes), y_advr_1.reshape(-1, advr_1_classes), y_advr_2.reshape(-1, advr_2_classes), ), axis=1) ) ) log_shapes([ X_train, X_valid, y_ally_train, y_ally_valid, y_advr_1_train, y_advr_1_valid, y_advr_2_train, y_advr_2_valid, ], locals(), 'Data size after train test split') scaler = StandardScaler() X_normalized_train = scaler.fit_transform(X_train) X_normalized_valid = scaler.transform(X_valid) log_shapes([X_normalized_train, X_normalized_valid], locals()) encoder = torch.load(encoder_ckpt) encoder.eval() optim = torch.optim.Adam criterionBCEWithLogits = nn.BCEWithLogitsLoss() criterionCrossEntropy = nn.CrossEntropyLoss() h = { 'epoch': { 'train': [], 'valid': [], }, 'encoder': { 'ally_train': [], 'ally_valid': [], 'advr_1_train': [], 'advr_1_valid': [], 'advr_2_train': [], 'advr_2_valid': [], }, } for _ in ['encoder']: dataset_train = utils.TensorDataset( torch.Tensor(X_normalized_train), torch.Tensor(y_ally_train.reshape(-1, ally_classes)), torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)), torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)), ) dataset_valid = utils.TensorDataset( torch.Tensor(X_normalized_valid), torch.Tensor(y_ally_valid.reshape(-1, ally_classes)), torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)), torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)), ) def transform(input_arg): return encoder(input_arg) dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=1) dataloader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=shuffle, num_workers=1) ally = DiscriminatorFCN(encoding_dim, hidden_dim, ally_classes, leaky).to(device) advr_1 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_1_classes, leaky).to(device) advr_2 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_2_classes, leaky).to(device) ally.apply(weights_init) advr_1.apply(weights_init) advr_2.apply(weights_init) sep('{}:{}'.format(_, 'ally')) summary(ally, input_size=(1, encoding_dim)) sep('{}:{}'.format(_, 'advr 1')) summary(advr_1, input_size=(1, encoding_dim)) sep('{}:{}'.format(_, 'advr 2')) summary(advr_2, input_size=(1, encoding_dim)) optimizer_ally = optim(ally.parameters(), lr=lr_ally) optimizer_advr_1 = optim(advr_1.parameters(), lr=lr_advr_1) optimizer_advr_2 = optim(advr_2.parameters(), lr=lr_advr_2) # adversary 1 sep("adversary 1") logging.info('{} \t {} \t {}'.format( 'Epoch', 'Advr 1 Train', 'Advr 1 Valid', )) for epoch in range(n_epochs): advr_1.train() nsamples = 0 iloss_advr = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = transform(data[0].to(device)) y_advr_train_torch = data[2].to(device) optimizer_advr_1.zero_grad() y_advr_train_hat_torch = advr_1(X_train_torch) loss_advr = criterionBCEWithLogits(y_advr_train_hat_torch, y_advr_train_torch) loss_advr.backward() 
optimizer_advr_1.step() nsamples += 1 iloss_advr += loss_advr.item() h[_]['advr_1_train'].append(iloss_advr / nsamples) if epoch % int(n_epochs / 10) != 0: continue advr_1.eval() nsamples = 0 iloss_advr = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = transform(data[0].to(device)) y_advr_valid_torch = data[2].to(device) y_advr_valid_hat_torch = advr_1(X_valid_torch) valid_loss_advr = criterionBCEWithLogits( y_advr_valid_hat_torch, y_advr_valid_torch, ) predicted = y_advr_valid_hat_torch > 0.5 nsamples += 1 iloss_advr += valid_loss_advr.item() total += y_advr_valid_torch.size(0) correct += (predicted == y_advr_valid_torch).sum().item() h[_]['advr_1_valid'].append(iloss_advr / nsamples) logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format( epoch, h[_]['advr_1_train'][-1], h[_]['advr_1_valid'][-1], correct / total)) # adversary sep("adversary 2") logging.info('{} \t {} \t {}'.format( 'Epoch', 'Advr 2 Train', 'Advr 2 Valid', )) for epoch in range(n_epochs): advr_2.train() nsamples = 0 iloss_advr = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = transform(data[0].to(device)) y_advr_train_torch = data[3].to(device) optimizer_advr_2.zero_grad() y_advr_train_hat_torch = advr_2(X_train_torch) loss_advr = criterionBCEWithLogits(y_advr_train_hat_torch, y_advr_train_torch) loss_advr.backward() optimizer_advr_2.step() nsamples += 1 iloss_advr += loss_advr.item() h[_]['advr_2_train'].append(iloss_advr / nsamples) if epoch % int(n_epochs / 10) != 0: continue advr_2.eval() nsamples = 0 iloss_advr = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = transform(data[0].to(device)) y_advr_valid_torch = data[3].to(device) y_advr_valid_hat_torch = advr_2(X_valid_torch) valid_loss_advr = criterionBCEWithLogits( y_advr_valid_hat_torch, y_advr_valid_torch) predicted = y_advr_valid_hat_torch > 0.5 nsamples += 1 iloss_advr += valid_loss_advr.item() total += y_advr_valid_torch.size(0) correct += (predicted == y_advr_valid_torch).sum().item() h[_]['advr_2_valid'].append(iloss_advr / nsamples) logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format( epoch, h[_]['advr_2_train'][-1], h[_]['advr_2_valid'][-1], correct / total)) sep("ally") logging.info('{} \t {} \t {}'.format( 'Epoch', 'Ally Train', 'Ally Valid', )) for epoch in range(n_epochs): ally.train() nsamples = 0 iloss_ally = 0 for i, data in enumerate(dataloader_train, 0): X_train_torch = transform(data[0].to(device)) y_ally_train_torch = data[1].to(device) optimizer_ally.zero_grad() y_ally_train_hat_torch = ally(X_train_torch) loss_ally = criterionBCEWithLogits(y_ally_train_hat_torch, y_ally_train_torch) loss_ally.backward() optimizer_ally.step() nsamples += 1 iloss_ally += loss_ally.item() if epoch not in h['epoch']['train']: h['epoch']['train'].append(epoch) h[_]['ally_train'].append(iloss_ally / nsamples) if epoch % int(n_epochs / 10) != 0: continue ally.eval() nsamples = 0 iloss_ally = 0 correct = 0 total = 0 for i, data in enumerate(dataloader_valid, 0): X_valid_torch = transform(data[0].to(device)) y_ally_valid_torch = data[1].to(device) y_ally_valid_hat_torch = ally(X_valid_torch) valid_loss_ally = criterionBCEWithLogits( y_ally_valid_hat_torch, y_ally_valid_torch) predicted = y_ally_valid_hat_torch > 0.5 nsamples += 1 iloss_ally += valid_loss_ally.item() total += y_ally_valid_torch.size(0) correct += (predicted == y_ally_valid_torch).sum().item() if epoch not in h['epoch']['valid']: h['epoch']['valid'].append(epoch) h[_]['ally_valid'].append(iloss_ally / 
nsamples) logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format( epoch, h[_]['ally_train'][-1], h[_]['ally_valid'][-1], correct / total)) checkpoint_location = \ 'checkpoints/{}/{}_training_history_{}.pkl'.format( expt, model, time_stamp) sep() logging.info('Saving: {}'.format(checkpoint_location)) pkl.dump(h, open(checkpoint_location, 'wb'))
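# In the evaluation above the pre-trained encoder is loaded, put in eval() mode, and
# only the downstream ally/adversary classifiers are optimized. A sketch of a feature
# transform that additionally avoids building an autograd graph through the frozen
# encoder (an optional refinement, not the exact code above):
import torch

def frozen_transform(encoder, x):
    encoder.eval()
    with torch.no_grad():
        return encoder(x)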
def main( model, time_stamp, device, ally_classes, advr_classes, batch_size, n_epochs, shuffle, init_weight, lr_encd, lr_ally, lr_advr, alpha, expt, encoder_ckpt, ally_ckpts, advr_ckpts, marker ): device = torch_device(device=device) encoder = define_G(cfg.num_channels[expt], cfg.num_channels[expt], 64, gpu_id=device) encoder.load_state_dict(torch.load(encoder_ckpt)) sep() logging.info("Loaded: {}".format(encoder_ckpt)) allies = [Net(num_classes=_).to(device) for _ in ally_classes] advrs = [Net(num_classes=_).to(device) for _ in advr_classes] for ally, ckpt in zip(allies, ally_ckpts): logging.info("Loaded: {}".format(ckpt)) ally.load_state_dict(torch.load(ckpt)) for advr, ckpt in zip(advrs, advr_ckpts): logging.info("Loaded: {}".format(ckpt)) advr.load_state_dict(torch.load(ckpt)) sep() optim = torch.optim.Adam criterionNLL = nn.NLLLoss() optimizer_encd = optim(encoder.parameters(), lr=lr_encd) optimizer_ally = [optim(ally.parameters(), lr=lr) for lr, ally in zip(lr_ally, allies)] optimizer_advr = [optim(advr.parameters(), lr=lr) for lr, advr in zip(lr_advr, advrs)] dataloader_train = get_loader(expt, batch_size, True) dataloader_valid = get_loader(expt, batch_size, False) epochs_train = [] epochs_valid = [] encd_loss_train = [] encd_loss_valid = [] ally_loss_train = [] ally_loss_valid = [] advr_loss_train = [] advr_loss_valid = [] template = '{}_{}_{}'.format(expt, model, marker) log_head = '{} \t {} \t {}'.format( 'Epoch', 'Encd Tr', 'Encd Val', ) for _ in range(len(ally_classes)): log_head += ' \t {} \t {}'.format( 'A{} tr'.format(_), 'A{} val'.format(_)) for _ in range(len(advr_classes)): log_head += ' \t {} \t {}'.format( 'V{} tr'.format(_), 'V{} val'.format(_)) logging.info(log_head) encoder.train() for ally in allies: ally.train() for advr in advrs: advr.train() for epoch in range(n_epochs): nsamples = 0 iloss = 0 for i, data in tqdm(enumerate(dataloader_train, 0), total=len(dataloader_train)): X_train_torch = data[0].to(device) y_ally_train_torch = [ (data[1] % 2 == 0).type(torch.int64).to(device)] y_advr_train_torch = [ data[1].to(device), # (data[1] >= 5).type(torch.int64).to(device) ] optimizer_encd.zero_grad() # Forward pass X_train_encoded = encoder(X_train_torch) y_ally_train_hat_torch = [ally(X_train_encoded) for ally in allies] y_advr_train_hat_torch = [advr(X_train_encoded) for advr in advrs] # Compute Loss loss_ally = [criterionNLL(y_hat, y) for y_hat, y in zip(y_ally_train_hat_torch, y_ally_train_torch)] loss_advr = [criterionNLL(y_hat, y) for y_hat, y in zip( y_advr_train_hat_torch, y_advr_train_torch)] loss_encd = sum(loss_ally) + sum(loss_advr) # Backward pass loss_encd.backward() optimizer_encd.step() nsamples += 1 iloss += loss_encd.item() epochs_train.append(epoch) encd_loss_train.append(iloss/nsamples) nsamples = 0 iloss_ally = np.array([0] * len(allies)) iloss_advr = np.array([0] * len(advrs)) for i, data in tqdm(enumerate(dataloader_train, 0), total=len(dataloader_train)): X_train_torch = data[0].to(device) y_ally_train_torch = [ (data[1] % 2 == 0).type(torch.int64).to(device)] y_advr_train_torch = [ data[1].to(device), # (data[1] >= 5).type(torch.int64).to(device) ] [opt_ally.zero_grad() for opt_ally in optimizer_ally] X_train_encoded = encoder(X_train_torch) y_ally_train_hat_torch = [ally(X_train_encoded) for ally in allies] loss_ally = [criterionNLL(y_hat, y) for y_hat, y in zip(y_ally_train_hat_torch, y_ally_train_torch)] [l_ally.backward() for l_ally in loss_ally] [opt_ally.step() for opt_ally in optimizer_ally] [opt_advr.zero_grad() for opt_advr in 
optimizer_advr] X_train_encoded = encoder(X_train_torch) y_advr_train_hat_torch = [advr(X_train_encoded) for advr in advrs] loss_advr = [criterionNLL(y_hat, y) for y_hat, y in zip(y_advr_train_hat_torch, y_advr_train_torch)] [l_advr.backward(retain_graph=True) for l_advr in loss_advr] [opt_advr.step() for opt_advr in optimizer_advr] nsamples += 1 iloss_ally = iloss_ally + \ np.array([l_ally.item() for l_ally in loss_ally]) iloss_advr = iloss_advr + \ np.array([l_advr.item() for l_advr in loss_advr]) ally_loss_train.append(iloss_ally/nsamples) advr_loss_train.append(iloss_advr/nsamples) if epoch % int(n_epochs/10) != 0: continue nsamples = 0 iloss = 0 iloss_ally = np.array([0] * len(allies)) iloss_advr = np.array([0] * len(advrs)) for i, data in tqdm(enumerate(dataloader_valid, 0), total=len(dataloader_valid)): X_valid_torch = data[0].to(device) y_ally_valid_torch = [ (data[1] % 2 == 0).type(torch.int64).to(device)] y_advr_valid_torch = [ data[1].to(device), # (data[1] >= 5).type(torch.int64).to(device) ] X_valid_encoded = encoder(X_valid_torch) y_ally_valid_hat_torch = [ally(X_valid_encoded) for ally in allies] y_advr_valid_hat_torch = [advr(X_valid_encoded) for advr in advrs] # Compute Loss loss_ally = [criterionNLL(y_hat, y) for y_hat, y in zip(y_ally_valid_hat_torch, y_ally_valid_torch)] loss_advr = [criterionNLL(y_hat, y) for y_hat, y in zip(y_advr_valid_hat_torch, y_advr_valid_torch)] loss_encd = sum(loss_ally) - sum(loss_advr) if i < 4: sample = X_valid_torch[0].cpu().detach().squeeze().numpy() ax = plt.subplot(2, 4, i+1) plt.imshow(sample) ax.axis('off') output = X_valid_encoded[0].cpu().detach().squeeze().numpy() ax = plt.subplot(2, 4, i+5) plt.imshow(output) ax.axis('off') if i == 3: validation_plt = 'ckpts/{}/validation/{}_{}.jpg'.format( expt, template, epoch) print('Saving: {}'.format(validation_plt)) plt.savefig(validation_plt) nsamples += 1 iloss += loss_encd.item() iloss_ally = iloss_ally + \ np.array([l_ally.item() for l_ally in loss_ally]) iloss_advr = iloss_advr + \ np.array([l_advr.item() for l_advr in loss_advr]) epochs_valid.append(epoch) encd_loss_valid.append(iloss/nsamples) ally_loss_valid.append(iloss_ally/nsamples) advr_loss_valid.append(iloss_advr/nsamples) logging.info('{} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f}'. 
            format(
                epoch,
                encd_loss_train[-1],
                encd_loss_valid[-1],
                ally_loss_train[-1][0],
                ally_loss_valid[-1][0],
                advr_loss_train[-1][0],
                advr_loss_valid[-1][0],
            ))

    ally_loss_train = np.vstack(ally_loss_train)
    ally_loss_valid = np.vstack(ally_loss_valid)
    advr_loss_train = np.vstack(advr_loss_train)
    advr_loss_valid = np.vstack(advr_loss_valid)

    fig = plt.figure(figsize=(15, 4))
    ax1 = fig.add_subplot(131)
    ax2 = fig.add_subplot(132)
    ax3 = fig.add_subplot(133)
    # ax4 = fig.add_subplot(224)

    ax1.plot(epochs_train, encd_loss_train, 'r', label='encd tr')
    ax1.plot(epochs_valid, encd_loss_valid, 'r--', label='encd val')
    ax1.legend()
    for col, c, ax in zip(range(ally_loss_train.shape[1]), ['b'], [ax2]):
        ax.plot(epochs_train, ally_loss_train[:, col], '{}.:'.format(c),
                label='ally {} tr'.format(col))
        ax.plot(epochs_valid, ally_loss_valid[:, col], '{}s-.'.format(c),
                label='ally {} val'.format(col))
        ax.legend()
    for col, c, ax in zip(range(advr_loss_train.shape[1]), ['g'], [ax3]):
        ax.plot(epochs_train, advr_loss_train[:, col], '{}.:'.format(c),
                label='advr {} tr'.format(col))
        ax.plot(epochs_valid, advr_loss_valid[:, col], '{}s-.'.format(c),
                label='advr {} val'.format(col))
        ax.legend()

    plot_location = 'ckpts/{}/plots/{}.png'.format(
        expt, template)
    sep()
    logging.info('Saving: {}'.format(plot_location))
    plt.savefig(plot_location)
    checkpoint_location = 'ckpts/{}/history/{}.pkl'.format(
        expt, template)
    logging.info('Saving: {}'.format(checkpoint_location))
    pkl.dump((
        epochs_train,
        epochs_valid,
        encd_loss_train,
        encd_loss_valid,
        ally_loss_train,
        ally_loss_valid,
        advr_loss_train,
        advr_loss_valid,
    ), open(checkpoint_location, 'wb'))
    model_ckpt = 'ckpts/{}/models/{}.pkl'.format(
        expt, template)
    logging.info('Saving: {}'.format(model_ckpt))
    torch.save(encoder.state_dict(), model_ckpt)
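# nn.NLLLoss, used as criterionNLL above, expects log-probabilities, so the Net
# classifiers loaded for the allies and adversaries are assumed to end in a
# log_softmax layer. Minimal illustration of that pairing (hypothetical tensors):
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)             # batch of 8 samples, 10 classes
log_probs = F.log_softmax(logits, dim=1)
targets = torch.randint(0, 10, (8,))
loss = F.nll_loss(log_probs, targets)   # matches CrossEntropyLoss on the raw logits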