Example #1
def test_discriminator(input, net):
    # Uncomment to force CPU execution:
    # torch.cuda.is_available = lambda: False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    discriminator = Discriminator(net).to(device)

    # optimizer_generator = Adam(generator.parameters())
    # criterion = nn.BCELoss()

    discriminator.apply(weights_init)

    # Print the model
    print(discriminator)

    # Forward one batch through D; .item() assumes a single-element output
    vector = discriminator(input)
    print('output of discriminator is:', vector.item())
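
Examples #1-#3 call a weights_init helper defined elsewhere in each project. A minimal sketch, assuming the standard DCGAN initialization convention (the real helpers may differ):

import torch.nn as nn

def weights_init(m):
    # DCGAN-style init: conv weights ~ N(0, 0.02); batch-norm weights
    # ~ N(1, 0.02) with zeroed biases. Applied via model.apply(weights_init).
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)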
Example #2
def train(dataloader,
          num_epochs,
          net,
          run_settings,
          learning_rate=0.0002,
          optimizerD='Adam'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the nets
    generator = Generator(net).to(device)
    discriminator = Discriminator(net).to(device)

    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create a batch of latent vectors used to visualize the generator's
    # progression (nz, the latent vector size, is defined at module level)
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    beta1 = 0.5

    # Set up the optimizers: Adam (default) or SGD for D, Adam for G
    if optimizerD == 'SGD':
        optimizerD = optim.SGD(discriminator.parameters(), lr=learning_rate)
    else:
        optimizerD = optim.Adam(discriminator.parameters(),
                                lr=learning_rate,
                                betas=(beta1, 0.999))
    optimizerG = optim.Adam(generator.parameters(),
                            lr=learning_rate,
                            betas=(beta1, 0.999))

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print("Starting Training Loop...")
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            discriminator.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)
            # Forward pass real batch through D
            output = discriminator(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 3 == 0:
                print(
                    '[%d/%d][%d/%d]\t\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch + 1, num_epochs, i + 1, len(dataloader),
                       errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving its output on fixed_noise
            if (iters % (len(dataloader) * 50) == 0) or \
                    ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1

    print("finished")

    for i in range(len(img_list)):
        plt.imshow(np.transpose(img_list[i], (1, 2, 0)))
        plt.savefig('generated_images_' + str(i) + '.png')

    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.savefig('generated_images_' + run_settings + '.png')

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig('loss_graph_' + run_settings + '.png')
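
For context, a hypothetical call site for train() above; the dataset root, image size, batch size, and the net/run_settings arguments are placeholders for whatever this project actually uses:

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

# Any ImageFolder-style directory of images works here (path is a placeholder)
dataset = dset.ImageFolder(root='data/images',
                           transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.CenterCrop(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5)),
                           ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

# net_config is hypothetical: whatever Generator/Discriminator expect here
train(dataloader, num_epochs=5, net=net_config, run_settings='adam_lr2e-4')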
Example #3
def train(args):
    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)

    gen = Generator(args.nz, 800)
    gen = gen.to(device)
    gen.apply(weights_init)

    discriminator = Discriminator(800)
    discriminator = discriminator.to(device)
    discriminator.apply(weights_init)

    bce = nn.BCELoss()
    bce = bce.to(device)

    galaxy_dataset = GalaxySet(args.data_path,
                               normalized=args.normalized,
                               out=args.out)
    loader = DataLoader(galaxy_dataset,
                        batch_size=args.bs,
                        shuffle=True,
                        num_workers=2,
                        drop_last=True)
    loader_iter = iter(loader)

    d_optimizer = Adam(discriminator.parameters(),
                       betas=(0.5, 0.999),
                       lr=args.lr)
    g_optimizer = Adam(gen.parameters(), betas=(0.5, 0.999), lr=args.lr)

    real_labels = to_var(torch.ones(args.bs), device_str)
    fake_labels = to_var(torch.zeros(args.bs), device_str)
    fixed_noise = to_var(torch.randn(1, args.nz), device_str)

    for i in tqdm(range(args.iters)):
        # Cycle through the dataloader indefinitely
        try:
            batch_data = next(loader_iter)
        except StopIteration:
            loader_iter = iter(loader)
            batch_data = next(loader_iter)

        batch_data = to_var(batch_data, device).unsqueeze(1)

        # Keep every other point of the first 1600, giving 800 features per row
        batch_data = batch_data[:, :, :1600:2]
        batch_data = batch_data.view(-1, 800)

        ### Train Discriminator ###

        d_optimizer.zero_grad()

        # train the discriminator on the real batch
        pred_real = discriminator(batch_data)
        d_loss = bce(pred_real, real_labels)

        # train the discriminator on the fake batch
        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes.detach())
        d_loss += bce(pred_fake, fake_labels)

        d_loss.backward()

        d_optimizer.step()

        ### Train Gen ###

        g_optimizer.zero_grad()

        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes)
        gen_loss = bce(pred_fake, real_labels)

        gen_loss.backward()
        g_optimizer.step()

        if i % 5000 == 0:
            print("Iteration %d >> g_loss: %.4f, d_loss: %.4f" %
                  (i, gen_loss.item(), d_loss.item()))
            # Checkpoints are overwritten on each save (index fixed at 0)
            torch.save(gen.state_dict(),
                       os.path.join(args.out, 'gen_%d.pkl' % 0))
            torch.save(discriminator.state_dict(),
                       os.path.join(args.out, 'disc_%d.pkl' % 0))
            gen.eval()
            fixed_fake = gen(fixed_noise).detach().cpu().numpy()
            real_data = batch_data[0].detach().cpu().numpy()
            gen.train()
            display_noise(fixed_fake.squeeze(),
                          os.path.join(args.out, "gen_sample_%d.png" % i))
            display_noise(real_data.squeeze(),
                          os.path.join(args.out, "real_%d.png" % 0))
Example #4
from nets import Generator, Discriminator
import input_pipe as ip
import weight_init  # project-local initialization helpers

import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import torch

from PIL import Image

input_size = 28 * 28
hidden_size = 500

G = Generator((100, hidden_size, input_size), 'relu')
D = Discriminator((input_size, hidden_size, 1), 'relu')

G.apply(weight_init.init_weights_normal)
D.apply(weight_init.init_weights_normal)

sample_size = int(1e6)
x_dataset = ip.MNIST_dataset()

# Superseded below: the input_pipe Gaussian sampler is replaced by torch.randn
# z_dataset = ip.mv_gaussian(0, 1, mv_size=100, sample_size=sample_size)
z_dataset = torch.randn(sample_size, 100)

cuda = torch.cuda.is_available()

print("PID: ", os.getpid())
if len(sys.argv) == 2:
    folder = sys.argv[1]
else:
    folder = input("Folder name\n")  # raw_input() is Python 2 only
if os.path.exists(folder):