Esempio n. 1
0
def main(args):
    """Train a baseline and a CVAE per quadrant setting and tabulate results.

    For every value in ``args.num_quadrant_inputs`` this loads the data,
    trains the baseline network, trains the CVAE (which is handed the
    trained baseline), saves a plot of conditional predictions and collects
    one table from ``generate_table``. The collected tables are joined side
    by side, a "Performance gap" row (first row minus second row) is added,
    and the result is written to ``results.csv``.

    Args:
        args: Parsed options. Attributes read here: ``cuda``,
            ``num_quadrant_inputs`` (iterable; each value is compared
            with 1), ``learning_rate``, ``num_epochs``,
            ``early_stop_patience``, ``num_images``, ``num_samples`` and
            ``num_particles``.

    Raises:
        ValueError: Propagated from ``pd.concat`` when
            ``args.num_quadrant_inputs`` yields nothing to concatenate.
    """
    use_cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda:0" if use_cuda else "cpu")

    tables = []
    headers = []

    for num_quadrant_inputs in args.num_quadrant_inputs:
        # Pluralise "quadrant" when more than one is used as input.
        if num_quadrant_inputs > 1:
            suffix = "s"
        else:
            suffix = ""
        # The same label names the table column and the final CSV column.
        header = "{} quadrant{}".format(num_quadrant_inputs, suffix)

        print("Training with {} quadrant{} as input...".format(
            num_quadrant_inputs, suffix))

        # Dataset (the first returned element is not needed here).
        _, dataloaders, dataset_sizes = get_data(
            num_quadrant_inputs=num_quadrant_inputs, batch_size=128)

        # Keyword arguments common to both training calls.
        shared_train_kwargs = dict(
            device=device,
            dataloaders=dataloaders,
            dataset_sizes=dataset_sizes,
            learning_rate=args.learning_rate,
            num_epochs=args.num_epochs,
            early_stop_patience=args.early_stop_patience,
        )

        # Baseline first: the CVAE training below receives it.
        baseline_net = baseline.train(
            model_path="baseline_net_q{}.pth".format(num_quadrant_inputs),
            **shared_train_kwargs
        )
        cvae_net = cvae.train(
            model_path="cvae_net_q{}.pth".format(num_quadrant_inputs),
            pre_trained_baseline_net=baseline_net,
            **shared_train_kwargs
        )

        # Keyword arguments common to the plot and the table.
        shared_eval_kwargs = dict(
            device=device,
            num_quadrant_inputs=num_quadrant_inputs,
            pre_trained_baseline=baseline_net,
            pre_trained_cvae=cvae_net,
        )

        # Visualize conditional predictions.
        visualize(
            num_images=args.num_images,
            num_samples=args.num_samples,
            image_path="cvae_plot_q{}.png".format(num_quadrant_inputs),
            **shared_eval_kwargs
        )

        # Retrieve conditional log likelihood.
        table = generate_table(
            num_particles=args.num_particles,
            col_name=header,
            **shared_eval_kwargs
        )
        tables.append(table)
        headers.append(header)

    # One column per quadrant setting; labels are assigned afterwards because
    # ignore_index=True discards the per-table column names.
    summary = pd.concat(tables, axis=1, ignore_index=True)
    summary.columns = headers
    first_row = summary.iloc[0, :]
    second_row = summary.iloc[1, :]
    summary.loc["Performance gap", :] = first_row - second_row
    summary.to_csv("results.csv")
                       '%s-sampled%d.png' % (variant, epoch),
                       nrow=10)

# Conditional-VAE variant: build an encoder/decoder pair, train for ten
# epochs, and save a grid of label-conditioned samples after every epoch.
elif variant == 'cvae':
    # Encoder input width is 10 + 28 * 28 -- presumably a 10-way one-hot label
    # concatenated with a flattened 28x28 image; confirm against cvae.train.
    encoder = neural_net.Encoder(10 + 28 * 28, 256, embedding_size)
    encoder.apply(neural_net.init_weights)

    # Decoder input width is 10 + embedding_size, matching the
    # [one-hot label, latent] concatenation used for sampling below.
    decoder = neural_net.Decoder(10 + embedding_size, 256, 28 * 28)
    decoder.apply(neural_net.init_weights)

    criterion = neural_net.binomial_loss
    # A single optimizer updates the encoder and decoder together.
    params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=1e-3)

    # Latents and labels are created once, outside the epoch loop, so every
    # saved grid is decoded from the same 100 (label, latent) pairs.
    t = torch.randn(100, embedding_size)
    # Ten copies of each label 0..9, in order: [0]*10 + [1]*10 + ... + [9]*10.
    c = torch.LongTensor(sum([[i] * 10 for i in range(10)], []))

    # Epochs are numbered 1..10; the number is used in the output file name.
    for epoch in range(1, 10 + 1):
        print('Epoch %d:' % epoch)
        losses = cvae.train(train_loader, encoder, decoder, criterion,
                            optimizer)
        # NOTE(review): `losses` is presumably a sequence of per-batch losses;
        # only its mean is reported.
        print(np.mean(losses))
        # Sampling is for visualisation only, so gradient tracking is disabled.
        with torch.no_grad():

            # Decode [one-hot label, latent] rows; sigmoid maps the decoder
            # output into (0, 1) (assumes the decoder emits logits -- confirm).
            images = torch.sigmoid(
                decoder(torch.cat([F.one_hot(c, num_classes=10).float(), t],
                                  1)))
            # Reshape to 100 single-channel 28x28 images and write them to
            # '<variant>-sampled<epoch>.png'. Assuming torchvision's
            # save_image, nrow=10 puts ten images (one label) on each row.
            save_image(images.view(100, 1, 28, 28),
                       '%s-sampled%d.png' % (variant, epoch),
                       nrow=10)