예제 #1
0
def main(
        model,
        time_stamp,
        device,
        ngpu,
        ally_classes,
        advr_1_classes,
        advr_2_classes,
        encoding_dim,
        hidden_dim,
        leaky,
        activation,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        init_weight,
        lr_encd,
        lr_ally,
        lr_advr_1,
        lr_advr_2,
        alpha,
        g_reps,
        d_reps,
        expt,
        marker
        ):
    """Adversarially train an encoder against one ally and two adversaries.

    Each epoch first updates the encoder to minimise
    ``loss_ally - loss_advr_1 - loss_advr_2`` (the ally and adversaries are
    held fixed), then updates the ally and both adversaries on the frozen
    encoding. Validation losses are computed roughly every tenth of the run.
    The loss history is plotted and pickled, and the encoder is saved.

    Args:
        model: Model name, used in the plot title and output paths.
        time_stamp: Run identifier, used in output paths.
        device: Device spec forwarded to ``torch_device``.
        ngpu: Unused in this function; kept for interface compatibility.
        ally_classes: Number of label columns for the ally target.
        advr_1_classes: Number of label columns for adversary 1's target.
        advr_2_classes: Number of label columns for adversary 2's target
            (converted to class indices via argmax for cross-entropy).
        encoding_dim: Width of the encoder output.
        hidden_dim: Hidden width of every network.
        leaky: Forwarded to ``GeneratorFCN`` / ``DiscriminatorFCN``.
        activation: Forwarded to ``GeneratorFCN``.
        test_size: Validation fraction passed to ``train_test_split``.
        batch_size: Mini-batch size for both data loaders.
        n_epochs: Number of training epochs.
        shuffle: Whether the data loaders shuffle.
        init_weight: If truthy, apply ``weights_init`` to all networks.
        lr_encd, lr_ally, lr_advr_1, lr_advr_2: Adam learning rates.
        alpha: Unused in this function; kept for interface compatibility.
        g_reps: Encoder passes over the training set per epoch.
        d_reps: Ally/adversary passes over the training set per epoch.
        expt: Experiment name; selects input data and output directories.
        marker: Prefix of the config summary embedded in output file names.
    """
    device = torch_device(device=device)

    X, y_ally, y_advr_1, y_advr_2 = load_processed_data(
        expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl')
    log_shapes(
        [X, y_ally, y_advr_1, y_advr_2],
        locals(),
        'Dataset loaded'
    )

    X_train, X_valid, \
        y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = train_test_split(
            X,
            y_ally,
            y_advr_1,
            y_advr_2,
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    y_ally.reshape(-1, ally_classes),
                    y_advr_1.reshape(-1, advr_1_classes),
                    y_advr_2.reshape(-1, advr_2_classes),
                ), axis=1)
            )
        )

    log_shapes(
        [
            X_train, X_valid,
            y_ally_train, y_ally_valid,
            y_advr_1_train, y_advr_1_valid,
            y_advr_2_train, y_advr_2_valid,
        ],
        locals(),
        'Data size after train test split'
    )

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    encoder = GeneratorFCN(
        X_normalized_train.shape[1], hidden_dim, encoding_dim,
        leaky, activation).to(device)
    ally = DiscriminatorFCN(
        encoding_dim, hidden_dim, ally_classes,
        leaky).to(device)
    advr_1 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_1_classes,
        leaky).to(device)
    advr_2 = DiscriminatorFCN(
        encoding_dim, hidden_dim, advr_2_classes,
        leaky).to(device)

    if init_weight:
        sep()
        logging.info('applying weights_init ...')
        encoder.apply(weights_init)
        ally.apply(weights_init)
        advr_1.apply(weights_init)
        advr_2.apply(weights_init)

    sep('encoder')
    summary(encoder, input_size=(1, X_normalized_train.shape[1]))
    sep('ally')
    summary(ally, input_size=(1, encoding_dim))
    sep('advr_1')
    summary(advr_1, input_size=(1, encoding_dim))
    sep('advr_2')
    summary(advr_2, input_size=(1, encoding_dim))

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()
    criterionCrossEntropy = nn.CrossEntropyLoss()

    # NOTE(review): weight_decay is set equal to each learning rate; this is
    # preserved from the original code — confirm it is intentional.
    optimizer_encd = optim(
        encoder.parameters(),
        lr=lr_encd,
        weight_decay=lr_encd
    )
    optimizer_ally = optim(
        ally.parameters(),
        lr=lr_ally,
        weight_decay=lr_ally
    )
    optimizer_advr_1 = optim(
        advr_1.parameters(),
        lr=lr_advr_1,
        weight_decay=lr_advr_1
    )
    optimizer_advr_2 = optim(
        advr_2.parameters(),
        lr=lr_advr_2,
        weight_decay=lr_advr_2
    )

    # Label columns follow the configured class counts (the same shapes used
    # for stratification above) rather than hard-coded widths.
    dataset_train = utils.TensorDataset(
        torch.Tensor(X_normalized_train),
        torch.Tensor(y_ally_train.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)),
    )

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )

    dataset_valid = utils.TensorDataset(
        torch.Tensor(X_normalized_valid),
        torch.Tensor(y_ally_valid.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)),
    )

    dataloader_valid = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )

    epochs_train = []
    epochs_valid = []
    encd_loss_train = []
    encd_loss_valid = []
    ally_loss_train = []
    ally_loss_valid = []
    advr_1_loss_train = []
    advr_1_loss_valid = []
    advr_2_loss_train = []
    advr_2_loss_valid = []

    # Validate about every tenth of the run; clamp to 1 so that runs with
    # fewer than 10 epochs do not hit a modulo-by-zero.
    log_every = max(1, n_epochs // 10)

    logging.info('{} \t {} \t {} \t {} \t {} \t {} \t {} \t {} \t {}'.format(
                'Epoch',
                'Encd Train',
                'Encd Valid',
                'Ally Train',
                'Ally Valid',
                'Advr 1 Train',
                'Advr 1 Valid',
                'Advr 2 Train',
                'Advr 2 Valid',
                ))

    for epoch in range(n_epochs):

        # --- Encoder phase: only the encoder's optimizer steps. ---
        encoder.train()
        ally.eval()
        advr_1.eval()
        advr_2.eval()

        for __ in range(g_reps):
            nsamples = 0
            iloss = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_ally_train_torch = data[1].to(device)
                y_advr_1_train_torch = data[2].to(device)
                y_advr_2_train_torch = data[3].to(device)

                optimizer_encd.zero_grad()
                # Forward pass
                X_train_encoded = encoder(X_train_torch)
                y_ally_train_hat_torch = ally(X_train_encoded)
                y_advr_1_train_hat_torch = advr_1(X_train_encoded)
                y_advr_2_train_hat_torch = advr_2(X_train_encoded)
                # Compute Loss: reward the ally, penalise both adversaries.
                loss_ally = criterionBCEWithLogits(
                    y_ally_train_hat_torch, y_ally_train_torch)
                loss_advr_1 = criterionBCEWithLogits(
                    y_advr_1_train_hat_torch,
                    y_advr_1_train_torch)
                loss_advr_2 = criterionCrossEntropy(
                    y_advr_2_train_hat_torch,
                    torch.argmax(y_advr_2_train_torch, 1))
                loss_encd = loss_ally - loss_advr_1 - loss_advr_2
                # Backward pass
                loss_encd.backward()
                optimizer_encd.step()

                nsamples += 1
                iloss += loss_encd.item()

        epochs_train.append(epoch)
        encd_loss_train.append(iloss/nsamples)

        # --- Discriminator phase: the encoder is frozen. ---
        encoder.eval()
        ally.train()
        advr_1.train()
        advr_2.train()

        for __ in range(d_reps):
            nsamples = 0
            iloss_ally = 0
            iloss_advr_1 = 0
            iloss_advr_2 = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_ally_train_torch = data[1].to(device)
                y_advr_1_train_torch = data[2].to(device)
                y_advr_2_train_torch = data[3].to(device)

                # The encoder does not change during this phase, so encode
                # once per batch and keep it out of the autograd graph.
                with torch.no_grad():
                    X_train_encoded = encoder(X_train_torch)

                optimizer_ally.zero_grad()
                y_ally_train_hat_torch = ally(X_train_encoded)
                loss_ally = criterionBCEWithLogits(
                    y_ally_train_hat_torch, y_ally_train_torch)
                loss_ally.backward()
                optimizer_ally.step()

                optimizer_advr_1.zero_grad()
                y_advr_1_train_hat_torch = advr_1(X_train_encoded)
                loss_advr_1 = criterionBCEWithLogits(
                    y_advr_1_train_hat_torch,
                    y_advr_1_train_torch)
                loss_advr_1.backward()
                optimizer_advr_1.step()

                optimizer_advr_2.zero_grad()
                y_advr_2_train_hat_torch = advr_2(X_train_encoded)
                loss_advr_2 = criterionCrossEntropy(
                    y_advr_2_train_hat_torch,
                    torch.argmax(y_advr_2_train_torch, 1))
                loss_advr_2.backward()
                optimizer_advr_2.step()

                nsamples += 1
                iloss_ally += loss_ally.item()
                iloss_advr_1 += loss_advr_1.item()
                iloss_advr_2 += loss_advr_2.item()

        ally_loss_train.append(iloss_ally/nsamples)
        advr_1_loss_train.append(iloss_advr_1/nsamples)
        advr_2_loss_train.append(iloss_advr_2/nsamples)

        if epoch % log_every != 0:
            continue

        # --- Validation: no parameter updates, so no autograd needed. ---
        encoder.eval()
        ally.eval()
        advr_1.eval()
        advr_2.eval()

        nsamples = 0
        iloss = 0
        iloss_ally = 0
        iloss_advr_1 = 0
        iloss_advr_2 = 0

        with torch.no_grad():
            for data in dataloader_valid:
                X_valid_torch = data[0].to(device)
                y_ally_valid_torch = data[1].to(device)
                y_advr_1_valid_torch = data[2].to(device)
                y_advr_2_valid_torch = data[3].to(device)

                X_valid_encoded = encoder(X_valid_torch)
                y_ally_valid_hat_torch = ally(X_valid_encoded)
                y_advr_1_valid_hat_torch = advr_1(X_valid_encoded)
                y_advr_2_valid_hat_torch = advr_2(X_valid_encoded)

                valid_loss_ally = criterionBCEWithLogits(
                    y_ally_valid_hat_torch, y_ally_valid_torch)
                valid_loss_advr_1 = criterionBCEWithLogits(
                    y_advr_1_valid_hat_torch, y_advr_1_valid_torch)
                valid_loss_advr_2 = criterionCrossEntropy(
                    y_advr_2_valid_hat_torch,
                    torch.argmax(y_advr_2_valid_torch, 1))
                valid_loss_encd = valid_loss_ally - valid_loss_advr_1 - \
                    valid_loss_advr_2

                nsamples += 1
                iloss += valid_loss_encd.item()
                iloss_ally += valid_loss_ally.item()
                iloss_advr_1 += valid_loss_advr_1.item()
                iloss_advr_2 += valid_loss_advr_2.item()

        epochs_valid.append(epoch)
        encd_loss_valid.append(iloss/nsamples)
        ally_loss_valid.append(iloss_ally/nsamples)
        advr_1_loss_valid.append(iloss_advr_1/nsamples)
        advr_2_loss_valid.append(iloss_advr_2/nsamples)

        logging.info(
            '{} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f}'.
            format(
                epoch,
                encd_loss_train[-1],
                encd_loss_valid[-1],
                ally_loss_train[-1],
                ally_loss_valid[-1],
                advr_1_loss_train[-1],
                advr_1_loss_valid[-1],
                advr_2_loss_train[-1],
                advr_2_loss_valid[-1],
            ))

    # NOTE(review): the 'val' field reports adversary 1's validation loss,
    # not the encoder's — preserved from the original; confirm intent.
    config_summary = '{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_lrencd_{}_lrally_{}_tr_{:.4f}_val_{:.4f}'\
        .format(
            marker,
            device,
            encoding_dim,
            hidden_dim,
            batch_size,
            n_epochs,
            lr_encd,
            lr_ally,
            encd_loss_train[-1],
            advr_1_loss_valid[-1],
        )

    plt.plot(epochs_train, encd_loss_train, 'r')
    plt.plot(epochs_valid, encd_loss_valid, 'r--')
    plt.plot(epochs_train, ally_loss_train, 'b')
    plt.plot(epochs_valid, ally_loss_valid, 'b--')
    plt.plot(epochs_train, advr_1_loss_train, 'g')
    plt.plot(epochs_valid, advr_1_loss_valid, 'g--')
    plt.plot(epochs_train, advr_2_loss_train, 'y')
    plt.plot(epochs_valid, advr_2_loss_valid, 'y--')
    plt.legend([
        'encoder train', 'encoder valid',
        'ally train', 'ally valid',
        'advr 1 train', 'advr 1 valid',
        'advr 2 train', 'advr 2 valid',
    ])
    plt.title("{} on {} training".format(model, expt))

    plot_location = 'plots/{}/{}_training_{}_{}.png'.format(
        expt, model, time_stamp, config_summary)
    sep()
    logging.info('Saving: {}'.format(plot_location))
    plt.savefig(plot_location)
    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}_{}.pkl'.format(
            expt, model, time_stamp, config_summary)
    logging.info('Saving: {}'.format(checkpoint_location))
    # Use a context manager so the history file is flushed and closed.
    with open(checkpoint_location, 'wb') as fh:
        pkl.dump((
            epochs_train, epochs_valid,
            encd_loss_train, encd_loss_valid,
            ally_loss_train, ally_loss_valid,
            advr_1_loss_train, advr_1_loss_valid,
            advr_2_loss_train, advr_2_loss_valid,
        ), fh)

    model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format(
            expt, model, time_stamp, config_summary)
    logging.info('Saving: {}'.format(model_ckpt))
    torch.save(encoder, model_ckpt)
예제 #2
0
def _count_correct(logits, targets):
    """Return how many rows of ``logits`` predict ``targets`` correctly.

    The inputs are treated as logits (they are scored with
    ``BCEWithLogitsLoss``), so for single-column targets a positive
    prediction is ``logit > 0``, i.e. sigmoid probability above 0.5.
    Multi-column targets are compared row-wise by argmax, so the result never
    exceeds the number of rows.
    """
    if targets.dim() > 1 and targets.size(1) > 1:
        return (torch.argmax(logits, 1) ==
                torch.argmax(targets, 1)).sum().item()
    predicted = (logits > 0).to(targets.dtype)
    return (predicted == targets).sum().item()


def _fit_discriminator(
        title, train_label, valid_label, net, optimizer, criterion,
        label_index, transform, dataloader_train, dataloader_valid,
        n_epochs, device, history, train_key, valid_key,
        epoch_history=None):
    """Train ``net`` on encoded inputs and log loss/accuracy periodically.

    Args:
        title: Section header passed to ``sep``.
        train_label, valid_label: Column names for the header log line.
        net: Classifier producing logits.
        optimizer: Optimizer over ``net``'s parameters.
        criterion: Loss applied to ``(net(x), y)``.
        label_index: Index of the target tensor within each dataloader batch.
        transform: Maps a raw input batch to the features fed to ``net``.
        dataloader_train, dataloader_valid: Data sources.
        n_epochs: Number of epochs to train.
        device: Device the batches are moved to.
        history: Dict receiving per-epoch losses under ``train_key`` and
            ``valid_key``.
        epoch_history: Optional dict with ``'train'``/``'valid'`` lists that
            records the epochs evaluated (each epoch stored once).
    """
    sep(title)
    logging.info('{} \t {} \t {}'.format(
        'Epoch',
        train_label,
        valid_label,
    ))

    # Validate about every tenth of the run; clamp to 1 so that runs with
    # fewer than 10 epochs do not hit a modulo-by-zero.
    log_every = max(1, n_epochs // 10)

    for epoch in range(n_epochs):
        net.train()

        nsamples = 0
        iloss = 0
        for data in dataloader_train:
            X_torch = transform(data[0].to(device))
            y_torch = data[label_index].to(device)

            optimizer.zero_grad()
            loss = criterion(net(X_torch), y_torch)
            loss.backward()
            optimizer.step()

            nsamples += 1
            iloss += loss.item()

        if epoch_history is not None and epoch not in epoch_history['train']:
            epoch_history['train'].append(epoch)
        history[train_key].append(iloss / nsamples)

        if epoch % log_every != 0:
            continue

        net.eval()

        nsamples = 0
        iloss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data in dataloader_valid:
                X_torch = transform(data[0].to(device))
                y_torch = data[label_index].to(device)
                y_hat = net(X_torch)

                nsamples += 1
                iloss += criterion(y_hat, y_torch).item()
                total += y_torch.size(0)
                correct += _count_correct(y_hat, y_torch)

        if epoch_history is not None and epoch not in epoch_history['valid']:
            epoch_history['valid'].append(epoch)
        history[valid_key].append(iloss / nsamples)

        logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
            epoch, history[train_key][-1], history[valid_key][-1],
            correct / total))


def main(
    model,
    time_stamp,
    device,
    ally_classes,
    advr_1_classes,
    advr_2_classes,
    encoding_dim,
    hidden_dim,
    leaky,
    test_size,
    batch_size,
    n_epochs,
    shuffle,
    lr_ally,
    lr_advr_1,
    lr_advr_2,
    expt,
    pca_ckpt,
    autoencoder_ckpt,
    encoder_ckpt,
):
    """Measure how well fresh classifiers recover each target from an encoder.

    A saved encoder is loaded and frozen. New ally / adversary classifiers are
    then trained independently on its output: adversary 1, then adversary 2,
    then the ally. Their training/validation losses are pickled as a history
    dict.

    Args:
        model: Model name, used in the output path.
        time_stamp: Run identifier, used in the output path.
        device: Device spec forwarded to ``torch_device``.
        ally_classes, advr_1_classes, advr_2_classes: Label column counts.
        encoding_dim: Width of the encoder output (classifier input).
        hidden_dim: Hidden width of the classifiers.
        leaky: Forwarded to ``DiscriminatorFCN``.
        test_size: Validation fraction passed to ``train_test_split``.
        batch_size, n_epochs, shuffle: Training schedule.
        lr_ally, lr_advr_1, lr_advr_2: Adam learning rates.
        expt: Experiment name; selects data and output directories.
        pca_ckpt, autoencoder_ckpt: Unused here; kept for interface
            compatibility.
        encoder_ckpt: Path to a full-model ``torch.save`` of the encoder.
    """
    device = torch_device(device=device)

    X, y_ally, y_advr_1, y_advr_2 = load_processed_data(
        expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl')
    log_shapes([X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded')

    X_train, X_valid, \
        y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = train_test_split(
            X,
            y_ally,
            y_advr_1,
            y_advr_2,
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    y_ally.reshape(-1, ally_classes),
                    y_advr_1.reshape(-1, advr_1_classes),
                    y_advr_2.reshape(-1, advr_2_classes),
                ), axis=1)
            )
        )

    log_shapes([
        X_train,
        X_valid,
        y_ally_train,
        y_ally_valid,
        y_advr_1_train,
        y_advr_1_valid,
        y_advr_2_train,
        y_advr_2_valid,
    ], locals(), 'Data size after train test split')

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    # map_location places the checkpoint on this run's device, so a model
    # saved on another device (e.g. CUDA) can be loaded here.
    encoder = torch.load(encoder_ckpt, map_location=device)
    encoder.eval()

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()

    h = {
        'epoch': {
            'train': [],
            'valid': [],
        },
        'encoder': {
            'ally_train': [],
            'ally_valid': [],
            'advr_1_train': [],
            'advr_1_valid': [],
            'advr_2_train': [],
            'advr_2_valid': [],
        },
    }

    name = 'encoder'

    dataset_train = utils.TensorDataset(
        torch.Tensor(X_normalized_train),
        torch.Tensor(y_ally_train.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)),
    )

    dataset_valid = utils.TensorDataset(
        torch.Tensor(X_normalized_valid),
        torch.Tensor(y_ally_valid.reshape(-1, ally_classes)),
        torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)),
        torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)),
    )

    def transform(input_arg):
        # The encoder is frozen: don't build an autograd graph through it.
        with torch.no_grad():
            return encoder(input_arg)

    dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                   batch_size=batch_size,
                                                   shuffle=shuffle,
                                                   num_workers=1)

    dataloader_valid = torch.utils.data.DataLoader(dataset_valid,
                                                   batch_size=batch_size,
                                                   shuffle=shuffle,
                                                   num_workers=1)

    ally = DiscriminatorFCN(encoding_dim, hidden_dim, ally_classes,
                            leaky).to(device)
    advr_1 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_1_classes,
                              leaky).to(device)
    advr_2 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_2_classes,
                              leaky).to(device)

    ally.apply(weights_init)
    advr_1.apply(weights_init)
    advr_2.apply(weights_init)

    sep('{}:{}'.format(name, 'ally'))
    summary(ally, input_size=(1, encoding_dim))
    sep('{}:{}'.format(name, 'advr 1'))
    summary(advr_1, input_size=(1, encoding_dim))
    sep('{}:{}'.format(name, 'advr 2'))
    summary(advr_2, input_size=(1, encoding_dim))

    optimizer_ally = optim(ally.parameters(), lr=lr_ally)
    optimizer_advr_1 = optim(advr_1.parameters(), lr=lr_advr_1)
    optimizer_advr_2 = optim(advr_2.parameters(), lr=lr_advr_2)

    common = dict(
        criterion=criterionBCEWithLogits,
        transform=transform,
        dataloader_train=dataloader_train,
        dataloader_valid=dataloader_valid,
        n_epochs=n_epochs,
        device=device,
        history=h[name],
    )

    # Batch layout: (X, y_ally, y_advr_1, y_advr_2).
    _fit_discriminator(
        "adversary 1", 'Advr 1 Train', 'Advr 1 Valid',
        advr_1, optimizer_advr_1, label_index=2,
        train_key='advr_1_train', valid_key='advr_1_valid', **common)

    # NOTE(review): adversary 2 is trained with BCE on (presumably one-hot)
    # multi-column labels, whereas other scripts use cross-entropy.
    _fit_discriminator(
        "adversary 2", 'Advr 2 Train', 'Advr 2 Valid',
        advr_2, optimizer_advr_2, label_index=3,
        train_key='advr_2_train', valid_key='advr_2_valid', **common)

    _fit_discriminator(
        "ally", 'Ally Train', 'Ally Valid',
        ally, optimizer_ally, label_index=1,
        train_key='ally_train', valid_key='ally_valid',
        epoch_history=h['epoch'], **common)

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    # Use a context manager so the history file is flushed and closed.
    with open(checkpoint_location, 'wb') as fh:
        pkl.dump(h, fh)
예제 #3
0
def main(
        model,
        time_stamp,
        device,
        ally_classes,
        advr_1_classes,
        advr_2_classes,
        encoding_dim,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        lr,
        expt,
        ):
    """Train a basic autoencoder on standardised features and save results.

    The labels are used only to stratify the train/validation split; the
    autoencoder is trained with MSE reconstruction loss. Training/validation
    losses are recorded about every tenth of the run, plotted, and pickled,
    and the trained model is saved.

    Args:
        model: Model name, used in output paths.
        time_stamp: Run identifier, used in output paths.
        device: Device spec forwarded to ``torch_device``.
        ally_classes, advr_1_classes, advr_2_classes: Label column counts
            used to build the stratification frame.
        encoding_dim: Bottleneck width of ``AutoEncoderBasic``.
        test_size: Validation fraction passed to ``train_test_split``.
        batch_size: Mini-batch size.
        n_epochs: Number of training epochs.
        shuffle: Whether the training loader shuffles.
        lr: Adam learning rate.
        expt: Experiment name; selects data and output directories.
    """
    device = torch_device(device=device)

    # refer to PrivacyGAN_Titanic for data preparation
    X, y_ally, y_advr_1, y_advr_2 = load_processed_data(
        expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl')
    log_shapes(
        [X, y_ally, y_advr_1, y_advr_2],
        locals(),
        'Dataset loaded'
    )

    X_train, X_valid = train_test_split(
            X,
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    y_ally.reshape(-1, ally_classes),
                    y_advr_1.reshape(-1, advr_1_classes),
                    y_advr_2.reshape(-1, advr_2_classes),
                ), axis=1)
            )
        )

    log_shapes(
        [
            X_train, X_valid,
        ],
        locals(),
        'Data size after train test split'
    )

    scaler = StandardScaler()
    X_train_normalized = scaler.fit_transform(X_train)
    X_valid_normalized = scaler.transform(X_valid)

    log_shapes([X_train_normalized, X_valid_normalized], locals())

    dataset_train = utils.TensorDataset(torch.Tensor(X_train_normalized))
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=2)

    dataset_valid = utils.TensorDataset(torch.Tensor(X_valid_normalized))
    dataloader_valid = torch.utils.data.DataLoader(
        dataset_valid, batch_size=batch_size, shuffle=False, num_workers=2)

    auto_encoder = AutoEncoderBasic(
        input_size=X_train_normalized.shape[1],
        encoding_dim=encoding_dim
    ).to(device)

    criterion = torch.nn.MSELoss()
    adam_optim = torch.optim.Adam
    optimizer = adam_optim(auto_encoder.parameters(), lr=lr)

    summary(auto_encoder, input_size=(1, X_train_normalized.shape[1]))

    h_epoch = []
    h_valid = []
    h_train = []

    # Record about every tenth of the run; clamp to 1 so that runs with
    # fewer than 10 epochs do not hit a modulo-by-zero.
    log_every = max(1, n_epochs // 10)

    sep()
    logging.info("epoch \t Aencoder_train \t Aencoder_valid")

    for epoch in range(n_epochs):
        # Validation below switches to eval mode, so re-enter train mode
        # at the start of every epoch.
        auto_encoder.train()

        nsamples = 0
        iloss = 0
        for data in dataloader_train:
            optimizer.zero_grad()

            X_torch = data[0].to(device)
            X_torch_hat = auto_encoder(X_torch)
            loss = criterion(X_torch_hat, X_torch)
            loss.backward()
            optimizer.step()

            nsamples += 1
            iloss += loss.item()

        if epoch % log_every != 0:
            continue

        h_epoch.append(epoch)
        h_train.append(iloss/nsamples)

        auto_encoder.eval()
        nsamples = 0
        iloss = 0
        with torch.no_grad():
            for data in dataloader_valid:
                X_torch = data[0].to(device)
                X_torch_hat = auto_encoder(X_torch)
                loss = criterion(X_torch_hat, X_torch)

                nsamples += 1
                iloss += loss.item()
        h_valid.append(iloss/nsamples)

        logging.info('{} \t {:.8f} \t {:.8f}'.format(
            h_epoch[-1],
            h_train[-1],
            h_valid[-1],
        ))

    config_summary = 'device_{}_dim_{}_batch_{}_epochs_{}_lr_{}_tr_{:.4f}_val_{:.4f}'\
        .format(
            device,
            encoding_dim,
            batch_size,
            n_epochs,
            lr,
            h_train[-1],
            h_valid[-1],
        )

    plt.plot(h_epoch, h_train, 'r--')
    plt.plot(h_epoch, h_valid, 'b--')
    plt.legend(['train_loss', 'valid_loss'])
    plt.title("autoencoder training {}".format(config_summary))

    plot_location = 'plots/{}/{}_training_{}_{}.png'.format(
        expt, model, time_stamp, config_summary)
    sep()
    logging.info('Saving: {}'.format(plot_location))
    plt.savefig(plot_location)
    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}_{}.pkl'.format(
            expt, model, time_stamp, config_summary)
    logging.info('Saving: {}'.format(checkpoint_location))
    # Use a context manager so the history file is flushed and closed.
    with open(checkpoint_location, 'wb') as fh:
        pkl.dump((h_epoch, h_train, h_valid), fh)

    model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format(
            expt, model, time_stamp, config_summary)
    logging.info('Saving: {}'.format(model_ckpt))
    torch.save(auto_encoder, model_ckpt)
예제 #4
0
def main(
        model,
        time_stamp,
        device,
        encoding_dim,
        hidden_dim,
        leaky,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        lr,
        expt,
        pca_ckpt,
        autoencoder_ckpt,
        encoder_ckpt,
        ):
    """Train a fresh adversary classifier on top of each frozen encoder checkpoint.

    For every encoder listed in ``ckpts`` a new ``DiscriminatorFCN`` is trained
    with ``BCEWithLogitsLoss`` to predict ``admission_type`` from the encoder's
    output. Per-checkpoint train/valid losses are collected in ``h`` (keyed by
    the ``ckpts`` key) and pickled to
    ``checkpoints/{expt}/{model}_training_history_{time_stamp}.pkl``;
    validation accuracy is only logged.

    ``pca_ckpt``, ``autoencoder_ckpt`` and ``encoder_ckpt`` are accepted for
    interface compatibility but are not used here.
    """
    device = torch_device(device=device)

    X, targets = load_processed_data(
        expt, 'processed_data_X_targets.pkl')
    log_shapes(
        [X] + [targets[i] for i in targets],
        locals(),
        'Dataset loaded'
    )

    # Each target becomes a column vector so it matches the (N, 1) logits.
    targets = {i: elem.reshape(-1, 1) for i, elem in targets.items()}

    X_train, X_valid, \
        y_adt_train, y_adt_valid = train_test_split(
            X,
            targets['admission_type'],
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    targets['admission_type'],
                ), axis=1)
            )
        )

    log_shapes(
        [
            X_train, X_valid,
            y_adt_train, y_adt_valid,
        ],
        locals(),
        'Data size after train test split'
    )

    y_train = y_adt_train
    y_valid = y_adt_valid

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    # Encoder checkpoints to probe. Every entry is commented out as committed,
    # so the loop below is a no-op until one is enabled.
    ckpts = {
        # 123A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_17_41_09.pkl
        # 0: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_16_13_39_A_n_1_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_0_encd_0.0471_advr_0.5991.pkl',
        # 28B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_00_23_29.pkl
        # 1: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_22_51_11_B_n_2_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_1_encd_0.0475_advr_0.5992.pkl',
        # 123A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_17_41_09.pkl
        # 2: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_16_14_37_A_n_3_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_2_encd_0.0464_advr_0.5991.pkl',
        # 224A: checkpoints/mimic/n_ind_gan_training_history_02_03_2020_20_09_39.pkl
        # 3: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_18_08_09_A_n_4_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_3_encd_0.0469_advr_0.5991.pkl',
        # 24A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_00_21_50.pkl
        # 4: 'checkpoints/mimic/n_eigan_torch_model_02_03_2020_23_12_05_A_n_5_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_4_encd_0.0468_advr_0.5994.pkl',
        # 67A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_05_30_09.pkl
        # 5: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_00_15_28_A_n_6_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_5_encd_0.0462_advr_0.5991.pkl',
        # 67A: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_05_30_09.pkl
        # 6: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_00_42_31_A_n_7_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_6_encd_0.0453_advr_0.5992.pkl',
        # 28B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_00_23_29.pkl
        # 7: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_22_54_21_B_n_8_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_7_encd_0.0477_advr_0.5992.pkl',
        # 9B: checkpoints/mimic/n_ind_gan_training_history_02_05_2020_18_09_03.pkl
        # 8: 'checkpoints/mimic/n_eigan_torch_model_02_05_2020_00_48_34_B_n_9_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_8_encd_0.0473_advr_0.6000.pkl',
        # nA: checkpoints/mimic/n_ind_gan_training_history_02_04_2020_20_13_29.pkl
        # 9: 'checkpoints/mimic/n_eigan_torch_model_02_04_2020_18_51_01_A_n_10_device_cuda_dim_256_hidden_512_batch_32768_epochs_1001_ally_9_encd_0.0420_advr_0.5992.pkl',
    }

    h = {}

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()

    # The data do not depend on the checkpoint, so build the loaders once.
    dataset_train = utils.TensorDataset(
        torch.Tensor(X_normalized_train),
        torch.Tensor(y_train),
    )
    dataset_valid = utils.TensorDataset(
        torch.Tensor(X_normalized_valid),
        torch.Tensor(y_valid),
    )
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )
    dataloader_valid = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=1
    )

    # Validate every tenth of the run; at least every epoch so that
    # n_epochs < 10 does not raise ZeroDivisionError in the modulo below.
    valid_every = max(1, n_epochs // 10)

    for idx, ckpt in ckpts.items():
        # NOTE(review): torch.load unpickles a whole module; only point this
        # at trusted, locally produced checkpoints.
        encoder = torch.load(ckpt, map_location=device)
        # The encoder is a frozen feature extractor; only clf is optimised.
        encoder.eval()

        h[idx] = {
            'epoch_train': [],
            'epoch_valid': [],
            'advr_train': [],
            'advr_valid': [],
        }

        clf = DiscriminatorFCN(
            encoding_dim, hidden_dim, 1,
            leaky).to(device)

        clf.apply(weights_init)

        sep('{} {}'.format(idx+1, 'ally'))
        summary(clf, input_size=(1, encoding_dim))

        optimizer = optim(clf.parameters(), lr=lr)

        # adversary 1
        sep("adversary with {} ally encoder".format(idx+1))
        logging.info('{} \t {} \t {} \t {}'.format(
            'Epoch',
            'Advr Train',
            'Advr Valid',
            'Advr Acc',
            ))

        for epoch in range(n_epochs):
            clf.train()

            nsamples = 0
            iloss_advr = 0
            for data in dataloader_train:
                # No graph through the frozen encoder: its gradients were
                # never used by any optimizer and only cost memory.
                with torch.no_grad():
                    X_train_torch = encoder(data[0].to(device))
                y_advr_train_torch = data[1].to(device)

                optimizer.zero_grad()
                y_advr_train_hat_torch = clf(X_train_torch)

                loss_advr = criterionBCEWithLogits(
                    y_advr_train_hat_torch, y_advr_train_torch)
                loss_advr.backward()
                optimizer.step()

                nsamples += 1
                iloss_advr += loss_advr.item()

            h[idx]['advr_train'].append(iloss_advr/nsamples)
            h[idx]['epoch_train'].append(epoch)

            if epoch % valid_every != 0:
                continue

            clf.eval()

            nsamples = 0
            iloss_advr = 0
            correct = 0
            total = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = encoder(data[0].to(device))
                    y_advr_valid_torch = data[1].to(device)
                    y_advr_valid_hat_torch = clf(X_valid_torch)

                    valid_loss_advr = criterionBCEWithLogits(
                        y_advr_valid_hat_torch, y_advr_valid_torch)

                    # clf emits logits (it is trained with BCEWithLogitsLoss),
                    # so probability > 0.5 corresponds to logit > 0.
                    predicted = (y_advr_valid_hat_torch > 0).float()

                    nsamples += 1
                    iloss_advr += valid_loss_advr.item()
                    total += y_advr_valid_torch.size(0)
                    correct += (predicted == y_advr_valid_torch).sum().item()

            h[idx]['advr_valid'].append(iloss_advr/nsamples)
            h[idx]['epoch_valid'].append(epoch)

            logging.info(
                '{} \t {:.8f} \t {:.8f} \t {:.8f}'.
                format(
                    epoch,
                    h[idx]['advr_train'][-1],
                    h[idx]['advr_valid'][-1],
                    correct/total
                ))

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as f:
        pkl.dump(h, f)
예제 #5
0
def main(
    model,
    time_stamp,
    device,
    ally_classes,
    advr_1_classes,
    advr_2_classes,
    encoding_dim,
    hidden_dim,
    leaky,
    test_size,
    batch_size,
    n_epochs,
    shuffle,
    lr,
    expt,
    pca_ckpt,
    autoencoder_ckpt,
    encoder_ckpt,
):
    """Train one binary classifier per target directly on the raw features.

    For every ``name, target`` pair loaded from ``processed_data_X_targets.pkl``
    the data are split (stratified on that target) and standardised. A fresh
    ``DiscriminatorFCN`` is then trained with ``BCEWithLogitsLoss``. Per-target
    train/valid losses are collected in ``h[name]`` and pickled to
    ``checkpoints/{expt}/{model}_training_history_{time_stamp}.pkl``;
    validation accuracy is only logged.

    ``ally_classes``, ``advr_1_classes``, ``advr_2_classes``, ``encoding_dim``,
    ``pca_ckpt``, ``autoencoder_ckpt`` and ``encoder_ckpt`` are accepted for
    interface compatibility but are not used here.
    """
    device = torch_device(device=device)

    X, targets = load_processed_data(expt, 'processed_data_X_targets.pkl')
    log_shapes([X] + [targets[i] for i in targets], locals(), 'Dataset loaded')

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()

    # Validate every tenth of the run; at least every epoch so that
    # n_epochs < 10 does not raise ZeroDivisionError in the modulo below.
    valid_every = max(1, n_epochs // 10)

    h = {}

    for name, target in targets.items():

        sep(name)

        # Column vector so it matches the (N, 1) logits.
        target = target.reshape(-1, 1)

        X_train, X_valid, \
            y_train, y_valid = train_test_split(
                X,
                target,
                test_size=test_size,
                stratify=target
            )

        log_shapes([
            X_train,
            X_valid,
            y_train,
            y_valid,
        ], locals(), 'Data size after train test split')

        scaler = StandardScaler()
        X_normalized_train = scaler.fit_transform(X_train)
        X_normalized_valid = scaler.transform(X_valid)

        log_shapes([X_normalized_train, X_normalized_valid], locals())

        h[name] = {
            'epoch_train': [],
            'epoch_valid': [],
            'y_train': [],
            'y_valid': [],
        }

        dataset_train = utils.TensorDataset(
            torch.Tensor(X_normalized_train),
            torch.Tensor(y_train.reshape(-1, 1)))

        dataset_valid = utils.TensorDataset(
            torch.Tensor(X_normalized_valid),
            torch.Tensor(y_valid.reshape(-1, 1)),
        )

        dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       num_workers=1)

        dataloader_valid = torch.utils.data.DataLoader(dataset_valid,
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       num_workers=1)

        # The classifier consumes the standardised raw features, so its input
        # width is the feature count rather than ``encoding_dim``. The two need
        # not match, and a mismatch would make the network reject the input.
        input_dim = X_normalized_train.shape[1]
        clf = DiscriminatorFCN(input_dim, hidden_dim, 1, leaky).to(device)

        clf.apply(weights_init)

        sep('{}:{}'.format(name, 'summary'))
        summary(clf, input_size=(1, input_dim))

        optimizer = optim(clf.parameters(), lr=lr)

        sep("TRAINING")
        logging.info('{} \t {} \t {} \t {}'.format(
            'Epoch',
            'Train',
            'Valid',
            'Acc',
        ))

        for epoch in range(n_epochs):

            clf.train()

            nsamples = 0
            iloss_train = 0

            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_train_torch = data[1].to(device)

                optimizer.zero_grad()
                y_train_hat_torch = clf(X_train_torch)

                loss_train = criterionBCEWithLogits(y_train_hat_torch,
                                                    y_train_torch)
                loss_train.backward()
                optimizer.step()

                nsamples += 1
                iloss_train += loss_train.item()

            h[name]['y_train'].append(iloss_train / nsamples)
            h[name]['epoch_train'].append(epoch)

            if epoch % valid_every != 0:
                continue

            clf.eval()

            nsamples = 0
            iloss_valid = 0
            correct = 0
            total = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_valid_torch = data[1].to(device)
                    y_valid_hat_torch = clf(X_valid_torch)

                    valid_loss = criterionBCEWithLogits(
                        y_valid_hat_torch,
                        y_valid_torch,
                    )

                    # clf emits logits (trained with BCEWithLogitsLoss), so
                    # probability > 0.5 corresponds to logit > 0.
                    predicted = (y_valid_hat_torch > 0).float()

                    nsamples += 1
                    iloss_valid += valid_loss.item()
                    total += y_valid_torch.size(0)
                    correct += (predicted == y_valid_torch).sum().item()

            h[name]['y_valid'].append(iloss_valid / nsamples)
            h[name]['epoch_valid'].append(epoch)

            logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
                epoch, h[name]['y_train'][-1], h[name]['y_valid'][-1],
                correct / total))

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as f:
        pkl.dump(h, f)
예제 #6
0
def main(
        model,
        time_stamp,
        device,
        ngpu,
        encoding_dim,
        hidden_dim,
        leaky,
        activation,
        test_size,
        batch_size,
        n_epochs,
        shuffle,
        init_weight,
        lr_encd,
        lr_ally,
        lr_advr,
        alpha,
        expt,
        num_allies,
        marker
        ):
    """Train an encoder jointly against ``num_allies`` allies and one adversary.

    Each epoch has two phases.

    1. Encoder step: allies and adversary are put in eval mode. The encoder
       minimises ``alpha / num_allies * sum(ally BCE) - (1 - alpha) * advr BCE``.
    2. Discriminator step: the encoder is frozen. Every ally and the adversary
       are trained on its output.

    The adversary predicts ``admission_type``. Ally ``j`` predicts the ``j``-th
    entry of ``y_ally_trains``. Validation runs every ``max(1, n_epochs // 10)``
    epochs.

    Outputs:

    - a loss plot under ``plots/{expt}/``;
    - the loss-history pickle and the whole encoder module (``torch.save``)
      under ``checkpoints/{expt}/``.

    ``ngpu`` is accepted for interface compatibility and is not used here.
    """

    device = torch_device(device=device)

    X, targets = load_processed_data(
        expt, 'processed_data_X_targets.pkl')
    log_shapes(
        [X] + [targets[i] for i in targets],
        locals(),
        'Dataset loaded'
    )

    # Column vectors so every target matches the (N, 1) logits.
    targets = {i: elem.reshape(-1, 1) for i, elem in targets.items()}

    X_train, X_valid, \
        y_hef_train, y_hef_valid, \
        y_exf_train, y_exf_valid, \
        y_gdr_train, y_gdr_valid, \
        y_lan_train, y_lan_valid, \
        y_mar_train, y_mar_valid, \
        y_rel_train, y_rel_valid, \
        y_ins_train, y_ins_valid, \
        y_dis_train, y_dis_valid, \
        y_adl_train, y_adl_valid, \
        y_adt_train, y_adt_valid, \
        y_etn_train, y_etn_valid = train_test_split(
            X,
            targets['hospital_expire_flag'],
            targets['expire_flag'],
            targets['gender'],
            targets['language'],
            targets['marital_status'],
            targets['religion'],
            targets['insurance'],
            targets['discharge_location'],
            targets['admission_location'],
            targets['admission_type'],
            targets['ethnicity'],
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    targets['admission_type'],
                ), axis=1)
            )
        )

    log_shapes(
        [
            X_train, X_valid,
            y_hef_train, y_hef_valid,
            y_exf_train, y_exf_valid,
            y_gdr_train, y_gdr_valid,
            y_lan_train, y_lan_valid,
            y_mar_train, y_mar_valid,
            y_rel_train, y_rel_valid,
            y_ins_train, y_ins_valid,
            y_dis_train, y_dis_valid,
            y_adl_train, y_adl_valid,
            y_adt_train, y_adt_valid,
            y_etn_train, y_etn_valid
        ],
        locals(),
        'Data size after train test split'
    )

    # Ally targets; ally j uses entry j (batch column j + 2 below).
    y_ally_trains = [
        y_hef_train,
        y_exf_train,
        y_gdr_train,
        y_lan_train,
        y_mar_train,
        y_rel_train,
        y_ins_train,
        y_dis_train,
        y_adl_train,
        y_etn_train,
    ]
    y_ally_valids = [
        y_hef_valid,
        y_exf_valid,
        y_gdr_valid,
        y_lan_valid,
        y_mar_valid,
        y_rel_valid,
        y_ins_valid,
        y_dis_valid,
        y_adl_valid,
        y_etn_valid,
    ]
    # The adversary is trained to predict admission_type.
    y_advr_train = y_adt_train
    y_advr_valid = y_adt_valid

    scaler = StandardScaler()
    X_normalized_train = scaler.fit_transform(X_train)
    X_normalized_valid = scaler.transform(X_valid)

    log_shapes([X_normalized_train, X_normalized_valid], locals())

    # Validate every tenth of the run; at least every epoch so that
    # n_epochs < 10 does not raise ZeroDivisionError in the modulo below.
    valid_every = max(1, n_epochs // 10)

    for i in [num_allies-1]:
        n_ally = i + 1
        sep('NUMBER OF ALLIES: {}'.format(i+1))
        encoder = GeneratorFCN(
            X_normalized_train.shape[1], hidden_dim, encoding_dim,
            leaky, activation).to(device)
        ally = {}
        for j in range(n_ally):
            ally[j] = DiscriminatorFCN(
                encoding_dim, hidden_dim, 1,
                leaky).to(device)
        advr = DiscriminatorFCN(
            encoding_dim, hidden_dim, 1,
            leaky).to(device)

        if init_weight:
            sep()
            logging.info('applying weights_init ...')
            encoder.apply(weights_init)
            for j in range(n_ally):
                ally[j].apply(weights_init)
            advr.apply(weights_init)

        sep('encoder')
        summary(encoder, input_size=(1, X_normalized_train.shape[1]))
        for j in range(n_ally):
            sep('ally:{}'.format(j))
            summary(ally[j], input_size=(1, encoding_dim))
        sep('advr')
        summary(advr, input_size=(1, encoding_dim))

        optim = torch.optim.Adam
        criterionBCEWithLogits = nn.BCEWithLogitsLoss()

        optimizer_encd = optim(encoder.parameters(), lr=lr_encd)
        optimizer_ally = {
            j: optim(ally[j].parameters(), lr=lr_ally) for j in range(n_ally)
        }
        optimizer_advr = optim(advr.parameters(), lr=lr_advr)

        # Batch layout: [0] features, [1] adversary target, [2 + j] ally j.
        dataset_train = utils.TensorDataset(
            torch.Tensor(X_normalized_train),
            torch.Tensor(y_advr_train),
            *[torch.Tensor(y) for y in y_ally_trains],
        )

        dataloader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=1
        )

        dataset_valid = utils.TensorDataset(
            torch.Tensor(X_normalized_valid),
            torch.Tensor(y_advr_valid),
            *[torch.Tensor(y) for y in y_ally_valids],
        )

        dataloader_valid = torch.utils.data.DataLoader(
            dataset_valid,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=1
        )

        def _split_batch(data):
            """Move one batch to ``device`` as (X, y_advr, {j: y_ally_j})."""
            X_torch = data[0].to(device)
            y_advr_torch = data[1].to(device)
            y_ally_torch = {j: data[j+2].to(device) for j in range(n_ally)}
            return X_torch, y_advr_torch, y_ally_torch

        epochs_train = []
        epochs_valid = []
        encd_loss_train = []
        encd_loss_valid = []
        ally_loss_train = {j: [] for j in range(n_ally)}
        ally_loss_valid = {j: [] for j in range(n_ally)}
        advr_loss_train = []
        advr_loss_valid = []

        log_list = ['epoch', 'encd_train', 'encd_valid', 'advr_train', 'advr_valid'] + \
            ['ally_{}_train \t ally_{}_valid'.format(str(j), str(j)) for j in range(i+1)]
        logging.info(' \t '.join(log_list))

        for epoch in range(n_epochs):

            # ---- phase 1: encoder step (allies and adversary held fixed) ----
            encoder.train()
            for j in range(n_ally):
                ally[j].eval()
            advr.eval()

            nsamples = 0
            iloss = 0
            for data in dataloader_train:
                X_train_torch, y_advr_train_torch, y_ally_train_torch = \
                    _split_batch(data)

                optimizer_encd.zero_grad()
                # Forward pass
                X_train_encoded = encoder(X_train_torch)
                y_advr_train_hat_torch = advr(X_train_encoded)
                # Compute Loss
                loss_ally = {
                    j: criterionBCEWithLogits(
                        ally[j](X_train_encoded), y_ally_train_torch[j])
                    for j in range(n_ally)
                }
                loss_advr = criterionBCEWithLogits(
                    y_advr_train_hat_torch,
                    y_advr_train_torch)

                # Keep the ally losses as tensors. Summing their .item()
                # values detaches that term, so the encoder would only be
                # trained against the adversary.
                loss_encd = alpha/num_allies * sum(loss_ally.values()) \
                    - (1-alpha) * loss_advr
                # Backward pass
                loss_encd.backward()
                optimizer_encd.step()

                nsamples += 1
                iloss += loss_encd.item()

            epochs_train.append(epoch)
            encd_loss_train.append(iloss/nsamples)

            # ---- phase 2: ally / adversary step (encoder frozen) ----
            encoder.eval()
            for j in range(n_ally):
                ally[j].train()
            advr.train()

            nsamples = 0
            iloss_ally = {j: 0 for j in range(n_ally)}
            iloss_advr = 0
            for data in dataloader_train:
                X_train_torch, y_advr_train_torch, y_ally_train_torch = \
                    _split_batch(data)

                loss_ally = {}
                for j in range(n_ally):
                    optimizer_ally[j].zero_grad()
                    # No graph through the frozen encoder.
                    with torch.no_grad():
                        X_train_encoded = encoder(X_train_torch)
                    loss_ally[j] = criterionBCEWithLogits(
                        ally[j](X_train_encoded), y_ally_train_torch[j])
                    loss_ally[j].backward()
                    optimizer_ally[j].step()

                optimizer_advr.zero_grad()
                with torch.no_grad():
                    X_train_encoded = encoder(X_train_torch)
                y_advr_train_hat_torch = advr(X_train_encoded)
                loss_advr = criterionBCEWithLogits(
                    y_advr_train_hat_torch,
                    y_advr_train_torch)
                loss_advr.backward()
                optimizer_advr.step()

                nsamples += 1
                for j in range(n_ally):
                    iloss_ally[j] += loss_ally[j].item()
                iloss_advr += loss_advr.item()

            for j in range(n_ally):
                ally_loss_train[j].append(iloss_ally[j]/nsamples)
            advr_loss_train.append(iloss_advr/nsamples)

            if epoch % valid_every != 0:
                continue

            # ---- validation ----
            encoder.eval()
            for j in range(n_ally):
                ally[j].eval()
            advr.eval()

            nsamples = 0
            iloss = 0
            iloss_ally = {j: 0 for j in range(n_ally)}
            iloss_advr = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch, y_advr_valid_torch, y_ally_valid_torch = \
                        _split_batch(data)

                    X_valid_encoded = encoder(X_valid_torch)
                    valid_loss_ally = {
                        j: criterionBCEWithLogits(
                            ally[j](X_valid_encoded), y_ally_valid_torch[j])
                        for j in range(n_ally)
                    }
                    valid_loss_advr = criterionBCEWithLogits(
                        advr(X_valid_encoded), y_advr_valid_torch)
                    valid_loss_encd = alpha/num_allies * sum(
                        valid_loss_ally.values()
                    ) - (1-alpha) * valid_loss_advr

                    nsamples += 1
                    iloss += valid_loss_encd.item()
                    for j in range(n_ally):
                        iloss_ally[j] += valid_loss_ally[j].item()
                    iloss_advr += valid_loss_advr.item()

            epochs_valid.append(epoch)
            encd_loss_valid.append(iloss/nsamples)
            for j in range(n_ally):
                ally_loss_valid[j].append(iloss_ally[j]/nsamples)
            advr_loss_valid.append(iloss_advr/nsamples)

            log_line = [
                str(epoch),
                '{:.8f}'.format(encd_loss_train[-1]),
                '{:.8f}'.format(encd_loss_valid[-1]),
                '{:.8f}'.format(advr_loss_train[-1]),
                '{:.8f}'.format(advr_loss_valid[-1]),
            ] + [
                '{:.8f} \t {:.8f}'.format(
                    ally_loss_train[_][-1],
                    ally_loss_valid[_][-1]
                ) for _ in ally_loss_train]
            logging.info(' \t '.join(log_line))

        config_summary = '{}_n_{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_ally_{}_encd_{:.4f}_advr_{:.4f}'\
            .format(
                marker,
                num_allies,
                device,
                encoding_dim,
                hidden_dim,
                batch_size,
                n_epochs,
                i,
                encd_loss_train[-1],
                advr_loss_valid[-1],
            )

        plt.figure()
        plt.plot(epochs_train, encd_loss_train, 'r', label='encd train')
        plt.plot(epochs_valid, encd_loss_valid, 'r--', label='encd valid')
        plt.plot(epochs_train, advr_loss_train, 'g', label='advr_train')
        plt.plot(epochs_valid, advr_loss_valid, 'g--', label='advr_valid')
        plt.legend()
        plt.title("{} on {} training".format(model, expt))

        plot_location = 'plots/{}/{}_training_{}_{}.png'.format(
            expt, model, time_stamp, config_summary)
        sep()
        logging.info('Saving: {}'.format(plot_location))
        plt.savefig(plot_location)
        checkpoint_location = \
            'checkpoints/{}/{}_training_history_{}_{}.pkl'.format(
                expt, model, time_stamp, config_summary)
        logging.info('Saving: {}'.format(checkpoint_location))
        with open(checkpoint_location, 'wb') as f:
            pkl.dump((
                epochs_train, epochs_valid,
                encd_loss_train, encd_loss_valid,
                ally_loss_train, ally_loss_valid,
                advr_loss_train, advr_loss_valid,
            ), f)

        model_ckpt = 'checkpoints/{}/{}_torch_model_{}_{}.pkl'.format(
                expt, model, time_stamp, config_summary)
        logging.info('Saving: {}'.format(model_ckpt))
        torch.save(encoder, model_ckpt)
예제 #7
0
def main(
    model,
    device,
    encoding_dim,
    batch_size,
    n_epochs,
    shuffle,
    lr,
    expt,
    marker=None,
):
    """Train an ``AutoEncoderBasic`` on ``expt`` with an MSE reconstruction loss.

    Validation runs every ``max(1, n_epochs // 10)`` epochs. Whenever the
    validation loss improves, the weights (``state_dict``) are saved as
    ``ckpts/{expt}/models/..._{marker}.best``. At the end the following are
    written:

    - a loss plot under ``ckpts/{expt}/plots/``;
    - the loss history under ``ckpts/{expt}/history/``;
    - the final weights as ``ckpts/{expt}/models/..._{marker}.stop``.

    Args:
        marker: suffix appended to every output file name. The body used to
            read ``marker`` without declaring it, i.e. it relied on a
            module-level global of that name. When omitted, that global is
            still used (or ``''`` if none exists), so existing callers keep
            their file names instead of hitting a NameError.
    """
    if marker is None:
        marker = globals().get('marker', '')

    device = torch_device(device=device)

    X_train, X_valid, \
        y_train, y_valid = get_data(expt)

    dataset_train = utils.TensorDataset(
        torch.Tensor(X_train.reshape(cfg.num_trains[expt], -1)))
    dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                   batch_size=batch_size,
                                                   shuffle=shuffle,
                                                   num_workers=2)

    dataset_valid = utils.TensorDataset(
        torch.Tensor(X_valid.reshape(cfg.num_tests[expt], -1)))
    dataloader_valid = torch.utils.data.DataLoader(dataset_valid,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   num_workers=2)

    auto_encoder = AutoEncoderBasic(input_size=cfg.input_sizes[expt],
                                    encoding_dim=encoding_dim).to(device)

    criterion = torch.nn.MSELoss()
    adam_optim = torch.optim.Adam
    optimizer = adam_optim(auto_encoder.parameters(), lr=lr)

    summary(auto_encoder, input_size=(1, cfg.input_sizes[expt]))

    h_epoch = []
    h_valid = []
    h_train = []

    sep()
    logging.info("epoch \t train \t valid")

    best = math.inf
    config_summary = 'device_{}_dim_{}_batch_{}_epochs_{}_lr_{}'.format(
        device,
        encoding_dim,
        batch_size,
        n_epochs,
        lr,
    )

    # Validate every tenth of the run; at least every epoch so that
    # n_epochs < 10 does not raise ZeroDivisionError in the modulo below.
    valid_every = max(1, n_epochs // 10)

    for epoch in range(n_epochs):
        # Validation switches to eval mode, so restore train mode each epoch.
        auto_encoder.train()

        nsamples = 0
        iloss = 0
        for data in dataloader_train:
            optimizer.zero_grad()

            X_torch = data[0].to(device)
            X_torch_hat = auto_encoder(X_torch)
            loss = criterion(X_torch_hat, X_torch)
            loss.backward()
            optimizer.step()

            nsamples += 1
            iloss += loss.item()

        if epoch % valid_every != 0:
            continue

        h_epoch.append(epoch)
        h_train.append(iloss / nsamples)

        auto_encoder.eval()
        nsamples = 0
        iloss = 0
        with torch.no_grad():
            for data in dataloader_valid:
                X_torch = data[0].to(device)
                X_torch_hat = auto_encoder(X_torch)
                loss = criterion(X_torch_hat, X_torch)

                nsamples += 1
                iloss += loss.item()
        h_valid.append(iloss / nsamples)
        if h_valid[-1] < best:
            best = h_valid[-1]

            model_ckpt = 'ckpts/{}/models/{}_{}_{}.best'.format(
                expt, model, config_summary, marker)
            logging.info('Saving: {}'.format(model_ckpt))
            torch.save(auto_encoder.state_dict(), model_ckpt)

        logging.info('{} \t {:.8f} \t {:.8f}'.format(
            h_epoch[-1],
            h_train[-1],
            h_valid[-1],
        ))

    fig = plt.figure(figsize=(5, 4))
    ax = fig.add_subplot(111)
    ax.plot(h_epoch, h_train, 'r.:')
    ax.plot(h_epoch, h_valid, 'rs-.')
    ax.set_xlabel('epochs')
    ax.set_ylabel('loss (MSE)')
    plt.legend(['train loss', 'valid loss'])

    plot_location = 'ckpts/{}/plots/{}_{}_{}.png'.format(
        expt, model, config_summary, marker)
    sep()
    logging.info('Saving: {}'.format(plot_location))
    plt.savefig(plot_location)
    checkpoint_location = 'ckpts/{}/history/{}_{}_{}.pkl'.format(
        expt, model, config_summary, marker)
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as f:
        pkl.dump((h_epoch, h_train, h_valid), f)

    model_ckpt = 'ckpts/{}/models/{}_{}_{}.stop'.format(
        expt, model, config_summary, marker)
    logging.info('Saving: {}'.format(model_ckpt))
    torch.save(auto_encoder.state_dict(), model_ckpt)
예제 #8
0
def main(model, time_stamp, device, ngpu, num_nodes, ally_classes,
         advr_1_classes, advr_2_classes, encoding_dim, hidden_dim, leaky,
         activation, test_size, batch_size, n_epochs, shuffle, init_weight,
         lr_encd, lr_ally, lr_advr_1, lr_advr_2, alpha, g_reps, d_reps, expt,
         marker):
    """Adversarially train one encoder per node on a column split of the data.

    The processed dataset is split into train/valid sets. Its feature
    columns are then split into ``num_nodes`` equal chunks with ``np.split``
    (which raises if the columns do not divide evenly). Each chunk is
    standardised and gets its own ``GeneratorFCN`` encoder, whose output
    width is ``encoding_dim // num_nodes``. Each chunk also gets one ally
    and two adversary ``DiscriminatorFCN`` heads.

    Every epoch has two phases:

    * Encoder phase (``g_reps`` passes): only the encoder's optimizer steps,
      minimising
      ``alpha * ally_loss - (1 - alpha) / 2 * (advr_1_loss + advr_2_loss)``.
    * Head phase (``d_reps`` passes): the ally and both adversaries are
      stepped on their own BCE-with-logits losses while the encoder is not
      updated.

    For each node this writes a loss plot, a pickled loss history and the
    whole encoder module (via ``torch.save``) to disk.

    Args:
        model: Model name used in plot titles and output file names.
        time_stamp: Run identifier embedded in output file names.
        device: Device spec passed to ``torch_device``.
        ngpu: Not used by this function; kept for interface compatibility.
        num_nodes: Number of column-wise data partitions / encoders.
        ally_classes, advr_1_classes, advr_2_classes: Label widths used to
            reshape the target arrays and to size the heads.
        encoding_dim: Total encoding width, divided evenly among nodes.
        hidden_dim, leaky, activation: Hyper-parameters forwarded to
            ``GeneratorFCN`` / ``DiscriminatorFCN``.
        test_size: Validation fraction passed to ``train_test_split``.
        batch_size, shuffle: ``DataLoader`` settings.
        n_epochs: Training epochs per node.
        init_weight: Not used by this function; kept for interface
            compatibility.
        lr_encd, lr_ally, lr_advr_1, lr_advr_2: Adam learning rates.
        alpha: Ally-vs-adversary trade-off in the encoder loss.
        g_reps, d_reps: Passes over the training data per epoch for the
            encoder phase and the head phase, respectively.
        expt: Experiment name; selects the data file and output folders.
        marker: Free-form tag prepended to the config summary.
    """
    device = torch_device(device=device)

    X, y_ally, y_advr_1, y_advr_2 = load_processed_data(
        expt, 'processed_data_X_y_ally_y_advr_y_advr_2.pkl')
    log_shapes([X, y_ally, y_advr_1, y_advr_2], locals(), 'Dataset loaded')

    # Stratify on the concatenated label columns so train and valid keep the
    # same joint distribution of ally / adversary labels.
    X_train, X_valid, \
        y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = train_test_split(
            X,
            y_ally,
            y_advr_1,
            y_advr_2,
            test_size=test_size,
            stratify=pd.DataFrame(np.concatenate(
                (
                    y_ally.reshape(-1, ally_classes),
                    y_advr_1.reshape(-1, advr_1_classes),
                    y_advr_2.reshape(-1, advr_2_classes),
                ), axis=1)
            )
        )

    log_shapes([
        X_train,
        X_valid,
        y_ally_train,
        y_ally_valid,
        y_advr_1_train,
        y_advr_1_valid,
        y_advr_2_train,
        y_advr_2_valid,
    ], locals(), 'Data size after train test split')

    # Column-wise split: node k sees the k-th equal slice of the features.
    X_train = np.split(X_train, num_nodes, axis=1)
    X_valid = np.split(X_valid, num_nodes, axis=1)

    log_shapes(X_train + X_valid, locals(),
               'Data size after splitting data among nodes')

    # The scaler is refit per node; each valid chunk is transformed with the
    # statistics of its matching train chunk (fit immediately before).
    scaler = StandardScaler()
    X_normalized_train, X_normalized_valid = [], []
    for train, valid in zip(X_train, X_valid):
        X_normalized_train.append(scaler.fit_transform(train))
        X_normalized_valid.append(scaler.transform(valid))

    log_shapes(X_normalized_train + X_normalized_valid, locals())

    node_encoding_dim = encoding_dim // num_nodes

    encoders = []
    allies = []
    adversaries_1 = []
    adversaries_2 = []
    for train in X_normalized_train:
        encoders.append(
            GeneratorFCN(train.shape[1], hidden_dim, node_encoding_dim,
                         leaky, activation).to(device))
        allies.append(
            DiscriminatorFCN(node_encoding_dim, hidden_dim,
                             ally_classes, leaky).to(device))
        adversaries_1.append(
            DiscriminatorFCN(node_encoding_dim, hidden_dim,
                             advr_1_classes, leaky).to(device))
        adversaries_2.append(
            DiscriminatorFCN(node_encoding_dim, hidden_dim,
                             advr_2_classes, leaky).to(device))

    sep('encoders')
    for encoder, train in zip(encoders, X_normalized_train):
        summary(encoder, input_size=(1, train.shape[1]))
    sep('ally')
    for ally in allies:
        summary(ally, input_size=(1, node_encoding_dim))
    sep('advr_1')
    for advr_1 in adversaries_1:
        summary(advr_1, input_size=(1, node_encoding_dim))
    sep('advr_2')
    for advr_2 in adversaries_2:
        summary(advr_2, input_size=(1, node_encoding_dim))

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()

    optimizers_encd = [optim(e.parameters(), lr=lr_encd) for e in encoders]
    optimizers_ally = [optim(a.parameters(), lr=lr_ally) for a in allies]
    optimizers_advr_1 = [
        optim(a.parameters(), lr=lr_advr_1) for a in adversaries_1]
    optimizers_advr_2 = [
        optim(a.parameters(), lr=lr_advr_2) for a in adversaries_2]

    # Validate/log roughly every 10% of training. max(1, ...) avoids a
    # modulo-by-zero crash when n_epochs < 10.
    log_every = max(1, n_epochs // 10)

    for k in range(num_nodes):
        sep('Node {}'.format(k))
        dataset_train = utils.TensorDataset(
            torch.Tensor(X_normalized_train[k]),
            torch.Tensor(y_ally_train.reshape(-1, ally_classes)),
            torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)),
            torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)),
        )

        dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       num_workers=1)

        dataset_valid = utils.TensorDataset(
            torch.Tensor(X_normalized_valid[k]),
            torch.Tensor(y_ally_valid.reshape(-1, ally_classes)),
            torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)),
            torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)),
        )

        dataloader_valid = torch.utils.data.DataLoader(dataset_valid,
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       num_workers=1)

        epochs_train = []
        epochs_valid = []
        encd_loss_train = []
        encd_loss_valid = []
        ally_loss_train = []
        ally_loss_valid = []
        advr_1_loss_train = []
        advr_1_loss_valid = []
        advr_2_loss_train = []
        advr_2_loss_valid = []

        logging.info(
            '{} \t {} \t {} \t {} \t {} \t {} \t {} \t {} \t {}'.format(
                'Epoch',
                'Encd Train',
                'Encd Valid',
                'Ally Train',
                'Ally Valid',
                'Advr 1 Train',
                'Advr 1 Valid',
                'Advr 2 Train',
                'Advr 2 Valid',
            ))

        encoder = encoders[k]
        ally = allies[k]
        advr_1 = adversaries_1[k]
        advr_2 = adversaries_2[k]

        optimizer_encd = optimizers_encd[k]
        optimizer_ally = optimizers_ally[k]
        optimizer_advr_1 = optimizers_advr_1[k]
        optimizer_advr_2 = optimizers_advr_2[k]

        for epoch in range(n_epochs):

            # Encoder phase: only optimizer_encd steps. Gradients that leak
            # into the heads are cleared by their zero_grad() calls below.
            encoder.train()
            ally.eval()
            advr_1.eval()
            advr_2.eval()

            for __ in range(g_reps):
                nsamples = 0
                iloss = 0
                for data in dataloader_train:
                    X_train_torch = data[0].to(device)
                    y_ally_train_torch = data[1].to(device)
                    y_advr_1_train_torch = data[2].to(device)
                    y_advr_2_train_torch = data[3].to(device)

                    optimizer_encd.zero_grad()
                    X_train_encoded = encoder(X_train_torch)
                    loss_ally = criterionBCEWithLogits(
                        ally(X_train_encoded), y_ally_train_torch)
                    loss_advr_1 = criterionBCEWithLogits(
                        advr_1(X_train_encoded), y_advr_1_train_torch)
                    loss_advr_2 = criterionBCEWithLogits(
                        advr_2(X_train_encoded), y_advr_2_train_torch)
                    # Reward ally accuracy, penalise adversary accuracy.
                    loss_encd = alpha * loss_ally - (1 - alpha) / 2 * (
                        loss_advr_1 + loss_advr_2)
                    loss_encd.backward()
                    optimizer_encd.step()

                    nsamples += 1
                    iloss += loss_encd.item()

            epochs_train.append(epoch)
            encd_loss_train.append(iloss / nsamples)

            # Head phase: the encoder is not updated, so each batch is
            # encoded once without building an autograd graph through it.
            encoder.eval()
            ally.train()
            advr_1.train()
            advr_2.train()

            for __ in range(d_reps):
                nsamples = 0
                iloss_ally = 0
                iloss_advr_1 = 0
                iloss_advr_2 = 0
                for data in dataloader_train:
                    X_train_torch = data[0].to(device)
                    y_ally_train_torch = data[1].to(device)
                    y_advr_1_train_torch = data[2].to(device)
                    y_advr_2_train_torch = data[3].to(device)

                    with torch.no_grad():
                        X_train_encoded = encoder(X_train_torch)

                    optimizer_ally.zero_grad()
                    loss_ally = criterionBCEWithLogits(
                        ally(X_train_encoded), y_ally_train_torch)
                    loss_ally.backward()
                    optimizer_ally.step()

                    optimizer_advr_1.zero_grad()
                    loss_advr_1 = criterionBCEWithLogits(
                        advr_1(X_train_encoded), y_advr_1_train_torch)
                    loss_advr_1.backward()
                    optimizer_advr_1.step()

                    optimizer_advr_2.zero_grad()
                    loss_advr_2 = criterionBCEWithLogits(
                        advr_2(X_train_encoded), y_advr_2_train_torch)
                    loss_advr_2.backward()
                    optimizer_advr_2.step()

                    nsamples += 1
                    iloss_ally += loss_ally.item()
                    iloss_advr_1 += loss_advr_1.item()
                    iloss_advr_2 += loss_advr_2.item()

            ally_loss_train.append(iloss_ally / nsamples)
            advr_1_loss_train.append(iloss_advr_1 / nsamples)
            advr_2_loss_train.append(iloss_advr_2 / nsamples)

            if epoch % log_every != 0:
                continue

            encoder.eval()
            ally.eval()
            advr_1.eval()
            advr_2.eval()

            nsamples = 0
            iloss = 0
            iloss_ally = 0
            iloss_advr_1 = 0
            iloss_advr_2 = 0

            # Evaluation only: no gradients are needed.
            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_ally_valid_torch = data[1].to(device)
                    y_advr_1_valid_torch = data[2].to(device)
                    y_advr_2_valid_torch = data[3].to(device)

                    X_valid_encoded = encoder(X_valid_torch)
                    valid_loss_ally = criterionBCEWithLogits(
                        ally(X_valid_encoded), y_ally_valid_torch)
                    valid_loss_advr_1 = criterionBCEWithLogits(
                        advr_1(X_valid_encoded), y_advr_1_valid_torch)
                    valid_loss_advr_2 = criterionBCEWithLogits(
                        advr_2(X_valid_encoded), y_advr_2_valid_torch)
                    valid_loss_encd = alpha * valid_loss_ally - \
                        (1 - alpha) / 2 * (valid_loss_advr_1 +
                                           valid_loss_advr_2)

                    nsamples += 1
                    iloss += valid_loss_encd.item()
                    iloss_ally += valid_loss_ally.item()
                    iloss_advr_1 += valid_loss_advr_1.item()
                    iloss_advr_2 += valid_loss_advr_2.item()

            epochs_valid.append(epoch)
            encd_loss_valid.append(iloss / nsamples)
            ally_loss_valid.append(iloss_ally / nsamples)
            advr_1_loss_valid.append(iloss_advr_1 / nsamples)
            advr_2_loss_valid.append(iloss_advr_2 / nsamples)

            logging.info(
                '{} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f} \t {:.8f}'
                .format(
                    epoch,
                    encd_loss_train[-1],
                    encd_loss_valid[-1],
                    ally_loss_train[-1],
                    ally_loss_valid[-1],
                    advr_1_loss_train[-1],
                    advr_1_loss_valid[-1],
                    advr_2_loss_train[-1],
                    advr_2_loss_valid[-1],
                ))

        config_summary = '{}_node_{}_{}_device_{}_dim_{}_hidden_{}_batch_{}_epochs_{}_lrencd_{}_lrally_{}_tr_{:.4f}_val_{:.4f}'\
            .format(
                marker,
                num_nodes,
                k,
                device,
                encoding_dim,
                hidden_dim,
                batch_size,
                n_epochs,
                lr_encd,
                lr_ally,
                encd_loss_train[-1],
                advr_1_loss_valid[-1],
            )

        # Fresh figure per node; otherwise every node's plot would also
        # contain the curves of all previously trained nodes.
        fig = plt.figure()
        plt.plot(epochs_train, encd_loss_train, 'r')
        plt.plot(epochs_valid, encd_loss_valid, 'r--')
        plt.plot(epochs_train, ally_loss_train, 'b')
        plt.plot(epochs_valid, ally_loss_valid, 'b--')
        plt.plot(epochs_train, advr_1_loss_train, 'g')
        plt.plot(epochs_valid, advr_1_loss_valid, 'g--')
        plt.plot(epochs_train, advr_2_loss_train, 'y')
        plt.plot(epochs_valid, advr_2_loss_valid, 'y--')
        plt.legend([
            'encoder train',
            'encoder valid',
            'ally train',
            'ally valid',
            'advr 1 train',
            'advr 1 valid',
            'advr 2 train',
            'advr 2 valid',
        ])
        plt.title("{}:{}/{} on {} training".format(model, k, num_nodes, expt))

        plot_location = 'plots/{}/{}_{}_{}_training_{}_{}.png'.format(
            expt, model, num_nodes, k, time_stamp, config_summary)
        sep()
        logging.info('Saving: {}'.format(plot_location))
        plt.savefig(plot_location)
        plt.close(fig)

        checkpoint_location = \
            'checkpoints/{}/{}_{}_{}_training_history_{}_{}.pkl'.format(
                expt, model, num_nodes, k, time_stamp, config_summary)
        logging.info('Saving: {}'.format(checkpoint_location))
        with open(checkpoint_location, 'wb') as history_file:
            pkl.dump((
                epochs_train,
                epochs_valid,
                encd_loss_train,
                encd_loss_valid,
                ally_loss_train,
                ally_loss_valid,
                advr_1_loss_train,
                advr_1_loss_valid,
                advr_2_loss_train,
                advr_2_loss_valid,
            ), history_file)

        model_ckpt = 'checkpoints/{}/{}_{}_{}_torch_model_{}_{}.pkl'.format(
            expt, model, num_nodes, k, time_stamp, config_summary)
        logging.info('Saving: {}'.format(model_ckpt))
        torch.save(encoder, model_ckpt)
예제 #9
0
def main(
    model,
    time_stamp,
    device,
    ally_classes,
    advr_1_classes,
    advr_2_classes,
    encoding_dim,
    hidden_dim,
    leaky,
    test_size,
    batch_size,
    n_epochs,
    shuffle,
    lr_ally,
    lr_advr_1,
    lr_advr_2,
    expt,
    pca_ckpt,
    autoencoder_ckpt,
    encoder_ckpt,
):
    """Train ally and adversary classifiers on PCA-projected features.

    Loads the normalised train/valid splits via ``get_data`` and projects
    them with the object loaded from ``pca_ckpt`` (``pca.eval``). It then
    trains three fresh ``DiscriminatorFCN`` heads, one after another:

    * Adversary 1 is trained with cross-entropy against the arg-max of its
      label columns. Accuracy compares arg-max predictions with arg-max
      targets.
    * Adversary 2 and the ally are trained with BCE-with-logits. Accuracy
      thresholds the logits at 0 (probability 0.5) and is computed
      element-wise over every label entry.

    Train loss is recorded every epoch. Validation loss and accuracy are
    computed roughly every 10% of ``n_epochs``. The history dict ``h`` is
    pickled to ``checkpoints/<expt>/<model>_training_history_<time_stamp>.pkl``.

    ``autoencoder_ckpt`` and ``encoder_ckpt`` are not used by this function;
    they are kept for interface compatibility.
    """
    device = torch_device(device=device)

    X_normalized_train, X_normalized_valid,\
        y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = get_data(expt, test_size)

    # NOTE: joblib.load unpickles the file, so only load trusted checkpoints.
    pca = joblib.load(pca_ckpt)

    optim = torch.optim.Adam
    criterionBCEWithLogits = nn.BCEWithLogitsLoss()
    criterionCrossEntropy = nn.CrossEntropyLoss()

    # Validate/log roughly every 10% of training. max(1, ...) avoids a
    # modulo-by-zero crash when n_epochs < 10.
    log_every = max(1, n_epochs // 10)

    h = {
        'epoch': {
            'train': [],
            'valid': [],
        },
        'pca': {
            'ally_train': [],
            'ally_valid': [],
            'advr_1_train': [],
            'advr_1_valid': [],
            'advr_2_train': [],
            'advr_2_valid': [],
        },
    }

    for method in ['pca']:
        if method == 'pca':
            dataset_train = utils.TensorDataset(
                torch.Tensor(pca.eval(X_normalized_train)),
                torch.Tensor(y_ally_train.reshape(-1, ally_classes)),
                torch.Tensor(y_advr_1_train.reshape(-1, advr_1_classes)),
                torch.Tensor(y_advr_2_train.reshape(-1, advr_2_classes)),
            )

            dataset_valid = utils.TensorDataset(
                torch.Tensor(pca.eval(X_normalized_valid)),
                torch.Tensor(y_ally_valid.reshape(-1, ally_classes)),
                torch.Tensor(y_advr_1_valid.reshape(-1, advr_1_classes)),
                torch.Tensor(y_advr_2_valid.reshape(-1, advr_2_classes)),
            )

        dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       num_workers=1)

        dataloader_valid = torch.utils.data.DataLoader(dataset_valid,
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       num_workers=1)

        ally = DiscriminatorFCN(encoding_dim, hidden_dim, ally_classes,
                                leaky).to(device)
        advr_1 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_1_classes,
                                  leaky).to(device)
        advr_2 = DiscriminatorFCN(encoding_dim, hidden_dim, advr_2_classes,
                                  leaky).to(device)

        ally.apply(weights_init)
        advr_1.apply(weights_init)
        advr_2.apply(weights_init)

        sep('{}:{}'.format(method, 'ally'))
        summary(ally, input_size=(1, encoding_dim))
        sep('{}:{}'.format(method, 'advr 1'))
        summary(advr_1, input_size=(1, encoding_dim))
        sep('{}:{}'.format(method, 'advr 2'))
        summary(advr_2, input_size=(1, encoding_dim))

        optimizer_ally = optim(ally.parameters(), lr=lr_ally)
        optimizer_advr_1 = optim(advr_1.parameters(), lr=lr_advr_1)
        optimizer_advr_2 = optim(advr_2.parameters(), lr=lr_advr_2)

        # adversary 1: cross-entropy on the arg-max of its label columns.
        sep("adversary 1")
        logging.info('{} \t {} \t {} \t {}'.format(
            'Epoch',
            'Advr 1 Train',
            'Advr 1 Valid',
            'Accuracy',
        ))

        for epoch in range(n_epochs):
            advr_1.train()

            nsamples = 0
            iloss_advr = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_advr_train_torch = data[2].to(device)

                optimizer_advr_1.zero_grad()
                y_advr_train_hat_torch = advr_1(X_train_torch)

                loss_advr = criterionCrossEntropy(
                    y_advr_train_hat_torch,
                    torch.argmax(y_advr_train_torch, 1))
                loss_advr.backward()
                optimizer_advr_1.step()

                nsamples += 1
                iloss_advr += loss_advr.item()

            h[method]['advr_1_train'].append(iloss_advr / nsamples)

            if epoch % log_every != 0:
                continue

            advr_1.eval()

            nsamples = 0
            iloss_advr = 0
            correct = 0
            total = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_advr_valid_torch = data[2].to(device)
                    y_advr_valid_hat_torch = advr_1(X_valid_torch)

                    actual = torch.argmax(y_advr_valid_torch, 1)
                    valid_loss_advr = criterionCrossEntropy(
                        y_advr_valid_hat_torch, actual)
                    predicted = torch.argmax(y_advr_valid_hat_torch, 1)

                    nsamples += 1
                    iloss_advr += valid_loss_advr.item()
                    total += actual.size(0)
                    correct += (predicted == actual).sum().item()

            h[method]['advr_1_valid'].append(iloss_advr / nsamples)

            logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
                epoch, h[method]['advr_1_train'][-1],
                h[method]['advr_1_valid'][-1], correct / total))

        # adversary 2: BCE-with-logits on its label columns.
        sep("adversary 2")
        logging.info('{} \t {} \t {} \t {}'.format(
            'Epoch',
            'Advr 2 Train',
            'Advr 2 Valid',
            'Accuracy',
        ))

        for epoch in range(n_epochs):
            advr_2.train()

            nsamples = 0
            iloss_advr = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_advr_train_torch = data[3].to(device)

                optimizer_advr_2.zero_grad()
                y_advr_train_hat_torch = advr_2(X_train_torch)

                loss_advr = criterionBCEWithLogits(y_advr_train_hat_torch,
                                                   y_advr_train_torch)
                loss_advr.backward()
                optimizer_advr_2.step()

                nsamples += 1
                iloss_advr += loss_advr.item()

            h[method]['advr_2_train'].append(iloss_advr / nsamples)

            if epoch % log_every != 0:
                continue

            advr_2.eval()

            nsamples = 0
            iloss_advr = 0
            correct = 0
            total = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_advr_valid_torch = data[3].to(device)
                    y_advr_valid_hat_torch = advr_2(X_valid_torch)

                    valid_loss_advr = criterionBCEWithLogits(
                        y_advr_valid_hat_torch, y_advr_valid_torch)

                    # advr_2 emits logits (it is trained with
                    # BCEWithLogitsLoss), so probability 0.5 is logit 0.
                    predicted = (y_advr_valid_hat_torch > 0).float()

                    nsamples += 1
                    iloss_advr += valid_loss_advr.item()
                    # Predictions are compared element-wise, so count every
                    # label entry, not just rows.
                    total += y_advr_valid_torch.numel()
                    correct += (predicted == y_advr_valid_torch).sum().item()

            h[method]['advr_2_valid'].append(iloss_advr / nsamples)

            logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
                epoch, h[method]['advr_2_train'][-1],
                h[method]['advr_2_valid'][-1], correct / total))

        # ally: BCE-with-logits on its label columns.
        sep("ally")
        logging.info('{} \t {} \t {} \t {}'.format(
            'Epoch',
            'Ally Train',
            'Ally Valid',
            'Accuracy',
        ))

        for epoch in range(n_epochs):
            ally.train()

            nsamples = 0
            iloss_ally = 0
            for data in dataloader_train:
                X_train_torch = data[0].to(device)
                y_ally_train_torch = data[1].to(device)

                optimizer_ally.zero_grad()
                y_ally_train_hat_torch = ally(X_train_torch)
                loss_ally = criterionBCEWithLogits(y_ally_train_hat_torch,
                                                   y_ally_train_torch)
                loss_ally.backward()
                optimizer_ally.step()

                nsamples += 1
                iloss_ally += loss_ally.item()
            if epoch not in h['epoch']['train']:
                h['epoch']['train'].append(epoch)
            h[method]['ally_train'].append(iloss_ally / nsamples)

            if epoch % log_every != 0:
                continue

            ally.eval()

            nsamples = 0
            iloss_ally = 0
            correct = 0
            total = 0

            with torch.no_grad():
                for data in dataloader_valid:
                    X_valid_torch = data[0].to(device)
                    y_ally_valid_torch = data[1].to(device)
                    y_ally_valid_hat_torch = ally(X_valid_torch)

                    valid_loss_ally = criterionBCEWithLogits(
                        y_ally_valid_hat_torch, y_ally_valid_torch)

                    # The ally emits logits, so probability 0.5 is logit 0.
                    predicted = (y_ally_valid_hat_torch > 0).float()

                    nsamples += 1
                    iloss_ally += valid_loss_ally.item()
                    total += y_ally_valid_torch.numel()
                    correct += (predicted == y_ally_valid_torch).sum().item()

            if epoch not in h['epoch']['valid']:
                h['epoch']['valid'].append(epoch)
            h[method]['ally_valid'].append(iloss_ally / nsamples)

            logging.info('{} \t {:.8f} \t {:.8f} \t {:.8f}'.format(
                epoch, h[method]['ally_train'][-1],
                h[method]['ally_valid'][-1], correct / total))

    checkpoint_location = \
        'checkpoints/{}/{}_training_history_{}.pkl'.format(
            expt, model, time_stamp)
    sep()
    logging.info('Saving: {}'.format(checkpoint_location))
    with open(checkpoint_location, 'wb') as history_file:
        pkl.dump(h, history_file)