Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_images', type=int, default=25)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--model', required=True)
    opts = parser.parse_args()

    # GPU
    device = 'cpu'
    if opts.cuda:
        if torch.cuda.is_available():
            device = 'cuda:0'
        else:
            print('CUDA not available!')

    # Load pretrained model
    netG = dcgan.Generator()
    netG.load_state_dict(torch.load(opts.model, map_location=device))
    netG = netG.to(device)
    netG.eval()

    # Start generation (no gradients are needed at inference time)
    noise = torch.randn(opts.num_images, NOISE_LENGTH, device=device)
    with torch.no_grad():
        fake = netG(noise)

    # Make sure the output directory exists before saving
    os.makedirs("samples", exist_ok=True)
    vutils.save_image(fake,
                      "samples/generate.png",
                      nrow=5,
                      normalize=True,
                      range=(-1, 1))
Example #2
def test_output_shape():
    netG = dcgan.Generator()
    noise = torch.randn(2, 100)
    fake = netG(noise)
    assert fake.shape == torch.Size([2, 3, 64, 64])

    netD = dcgan.Critic()
    fake = torch.randn(2, 3, 64, 64)
    output = netD(fake)
    assert output.shape == torch.Size((2, ))
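
The dcgan module used by these examples is not shown. A minimal sketch that satisfies the shape assertions above (flat 100-dimensional noise in, a 3x64x64 image out of the Generator; one score per sample out of the Critic) could look like the following; the layer layout follows the standard DCGAN architecture and is an assumption, not the source code:

import torch.nn as nn

NOISE_LENGTH = 100  # implied by the torch.randn(2, 100) call in the test


class Generator(nn.Module):
    """Maps flat noise of length 100 to a 3x64x64 image in [-1, 1]."""

    def __init__(self, nz=NOISE_LENGTH, ngf=64, nc=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),      # 1x1 -> 4x4
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),  # 4x4 -> 8x8
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),  # 8x8 -> 16x16
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),      # 16x16 -> 32x32
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),           # 32x32 -> 64x64
            nn.Tanh(),
        )

    def forward(self, z):
        # Treat the flat noise vector as an nz x 1 x 1 feature map.
        return self.net(z.view(z.size(0), -1, 1, 1))


class Critic(nn.Module):
    """Maps a 3x64x64 image to one unbounded score per sample, shape (N,)."""

    def __init__(self, gradient_penalty=False, ndf=64, nc=3):
        super().__init__()
        # With a gradient penalty, batch norm is usually replaced (instance norm here).
        norm = nn.InstanceNorm2d if gradient_penalty else nn.BatchNorm2d
        self.net = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),           # 64 -> 32
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),      # 32 -> 16
            norm(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # 16 -> 8
            norm(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),  # 8 -> 4
            norm(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),        # 4 -> 1
        )

    def forward(self, x):
        return self.net(x).view(x.size(0))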
Example #3
def main():
    set_random_seed()

    device, gpu_ids = util.get_available_devices()

    # Arguments
    opt = args.get_setup_args()

    # Number of channels in the training images
    nc = opt.channels
    # Size of z latent vector (i.e. size of generator input)
    nz = opt.latent_dim
    # Size of feature maps in generator
    ngf = 64

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def evaluate(source_images_path, keep_images=True):
        dataset = datasets.ImageFolder(root=source_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))

        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_gen_images_path = os.path.join(opt.output_path, opt.version,
                                              opt.eval_mode)
        os.makedirs(output_gen_images_path, exist_ok=True)

        output_source_images_path = source_images_path + "_" + str(
            opt.img_size)

        source_images_available = True

        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(dataloader, 0):
            images, _ = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, nz, 1, 1), device=device)

            # No gradients are needed when generating images for FID evaluation
            with torch.no_grad():
                gen_images = netG(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_gen_images_path,
                                                     i),
                                  normalize=True)
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_gen_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_gen_images_path)
        return fid

    test_images_path = os.path.join(opt.data_path, "test")
    val_images_path = os.path.join(opt.data_path, "val")
    model_path = os.path.join(opt.output_path, opt.version, opt.model_file)

    netG = dcgan.Generator(nc, nz, ngf).to(device)

    if (opt.model_file.endswith(".pt")):
        netG.load_state_dict(torch.load(model_path))
    elif (opt.model_file.endswith(".tar")):
        checkpoint = torch.load(model_path)
        netG.load_state_dict(checkpoint['g_state_dict'])

    netG.eval()

    if opt.eval_mode == "val":
        source_images_path = val_images_path
    else:
        source_images_path = test_images_path

    if opt.eval_mode == "val" or opt.eval_mode == "test":
        print("Evaluating model...")
        fid = evaluate(source_images_path)
        print("FID: {}".format(fid))
Example #4
def train(dataloader,
          gradient_penalty,
          n_epochs,
          lr,
          n_critic=5,
          c=0.01,
          penalty_factor=10):
    def weights_init(m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    netG = dcgan.Generator().to(device)
    netG.apply(weights_init)

    netD = dcgan.Critic(gradient_penalty).to(device)
    netD.apply(weights_init)

    # WGAN-GP uses Adam with betas (0, 0.9); plain weight-clipped WGAN uses RMSprop
    if gradient_penalty:
        optG = optim.Adam(netG.parameters(), lr, betas=(0, 0.9))
        optD = optim.Adam(netD.parameters(), lr, betas=(0, 0.9))
    else:
        optG = optim.RMSprop(netG.parameters(), lr)
        optD = optim.RMSprop(netD.parameters(), lr)

    noise = torch.randn(dataloader.batch_size, NOISE_LENGTH, device=device)

    # Make sure the output directories exist before the first save
    os.makedirs("samples", exist_ok=True)
    os.makedirs("pretrain", exist_ok=True)

    batches_done = 0

    for epoch in range(n_epochs):
        # Train critic
        for i, (data, _) in enumerate(dataloader):
            data = data.to(device)
            real_batch_size = len(data)

            optD.zero_grad()
            noise.normal_()
            fake = netG(noise[:real_batch_size])
            # Gradients from G are not used here, so detach to avoid computing them
            # Maximize the critic objective (3) -> minimize its negative
            lossD = -(netD(data).mean() - netD(fake.detach()).mean())
            if gradient_penalty:
                lossD += penalty_factor * compute_gradient_penalty(
                    data, fake, netD)

            lossD.backward()
            optD.step()
            # Clamp the weights
            if not gradient_penalty:
                for p in netD.parameters():
                    p.data.clamp_(-c, c)

            # Train generator
            if i % n_critic == 0:
                optG.zero_grad()
                noise.normal_()
                # Minimize EM distance
                lossG = -netD(netG(noise[:real_batch_size])).mean()
                lossG.backward()
                optG.step()
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                      (epoch, n_epochs, batches_done % len(dataloader),
                       len(dataloader), lossD.item(), lossG.item()))

            if batches_done % 100 == 0:
                # Sample from the fixed noise baseline to track progress
                with torch.no_grad():
                    fake = netG(NOISE_BASELINE)
                vutils.save_image(fake,
                                  "samples/%d.png" % batches_done,
                                  nrow=5,
                                  normalize=True,
                                  range=(-1, 1))

            batches_done += 1

        # Save the model
        torch.save(netG.state_dict(), 'pretrain/netG_epoch_%d.pth' % epoch)
        torch.save(netD.state_dict(), 'pretrain/netD_epoch_%d.pth' % epoch)
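
Example #4 relies on a compute_gradient_penalty helper that is not shown above. A minimal sketch of the usual WGAN-GP penalty, interpolating between real and fake batches and penalizing deviations of the critic's gradient norm from 1, is given below; the name and call signature are taken from the call site, the body is an assumption:

import torch


def compute_gradient_penalty(real, fake, netD):
    # Random per-sample interpolation factor, broadcast over C, H, W.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = netD(interpolates)
    # Gradient of the critic scores w.r.t. the interpolated images.
    grads = torch.autograd.grad(outputs=scores,
                                inputs=interpolates,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True,
                                retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # Two-sided penalty: (||grad||_2 - 1)^2, averaged over the batch.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()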
Example #5
                               transforms.ToTensor()
                               #transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
                           ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=2)


###### model #######

ngf = opt.ngf
ndf = opt.ndf
nz = opt.nz
nc = 3

netG = dcgan.Generator(nc, ngf, nz, opt.fineSize)
netD = dcgan.Discriminator(nc, ndf, opt.hidden_size, opt.fineSize)
if opt.cuda:
    netG.cuda()
    netD.cuda()

#### setup optimizer #####
optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

criterion = nn.L1Loss()

####### Variables ##########
noise = torch.FloatTensor(opt.batchSize, opt.nz, 1, 1)
real = torch.FloatTensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
label = torch.FloatTensor(1)
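
The pre-allocated FloatTensor buffers above follow the old pre-0.4 PyTorch idiom. In current PyTorch the same tensors would usually be created directly on the target device, for example (a sketch reusing the shapes from this snippet; the device handling is an assumption):

import torch

device = torch.device('cuda' if opt.cuda else 'cpu')

noise = torch.randn(opt.batchSize, opt.nz, 1, 1, device=device)
real = torch.empty(opt.batchSize, nc, opt.fineSize, opt.fineSize, device=device)
label = torch.empty(1, device=device)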
Example #6
def main():

    set_random_seed()

    #cuda = True if torch.cuda.is_available() else False
    device, gpu_ids = util.get_available_devices()

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # Arguments
    opt = args.get_setup_args()

    # Number of channels in the training images
    nc = opt.channels
    # Size of z latent vector (i.e. size of generator input)
    nz = opt.latent_dim
    # Size of feature maps in generator
    ngf = 64
    # Size of feature maps in discriminator
    ndf = 64

    num_classes = opt.num_classes

    train_images_path = os.path.join(opt.data_path, "train")
    val_images_path = os.path.join(opt.data_path, "val")
    output_model_path = os.path.join(opt.output_path, opt.version)
    output_train_images_path = os.path.join(opt.output_path, opt.version,
                                            "train")
    output_sample_images_path = os.path.join(opt.output_path, opt.version,
                                             "sample")

    os.makedirs(output_train_images_path, exist_ok=True)
    os.makedirs(output_sample_images_path, exist_ok=True)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Initialize generator and discriminator
    netG = dcgan.Generator(nc, nz, ngf).to(device)
    netG.apply(weights_init)
    netD = dcgan.Discriminator(nc, ndf).to(device)
    netD.apply(weights_init)

    # Create batch of latent vectors to visualize
    # the progress of the generator
    # sample_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1.  # use floats: BCELoss expects floating-point targets
    fake_label = 0.

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.b1, opt.b2))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.b1, opt.b2))

    train_set = datasets.ImageFolder(root=train_images_path,
                                     transform=transforms.Compose([
                                         transforms.Resize(
                                             (opt.img_size, opt.img_size)),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5),
                                                              (0.5, 0.5, 0.5))
                                     ]))

    dataloader = torch.utils.data.DataLoader(train_set,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers)

    # ----------
    #  Training
    # ----------

    G_losses = []
    D_losses = []
    FIDs = []
    val_epochs = []

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def validate(keep_images=True):
        val_set = datasets.ImageFolder(root=val_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_images_path = os.path.join(opt.output_path, opt.version, "val")
        os.makedirs(output_images_path, exist_ok=True)

        output_source_images_path = val_images_path + "_" + str(opt.img_size)

        source_images_available = True

        if (not os.path.exists(output_source_images_path)):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(val_loader, 0):
            images, _ = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, nz, 1, 1)).to(device)

            gen_images = netG(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_images_path, i),
                                  normalize=True)
                if (not source_images_available):
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_images_path, output_source_images_path)
        if (not keep_images):
            print("Deleting images generated for validation...")
            rmtree(output_images_path)
        return fid

    def sample_images(num_images, batches_done):
        # Sample noise
        z = torch.randn((num_classes * num_images, nz, 1, 1)).to(device)
        sample_imgs = netG(z)
        vutils.save_image(sample_imgs.data,
                          "{}/{}.png".format(output_sample_images_path,
                                             batches_done),
                          nrow=num_images,
                          padding=2,
                          normalize=True)

    def save_loss_plot(path):
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_losses, label="G")
        plt.plot(D_losses, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(path)

    def save_fid_plot(FIDs, epochs, path):
        #N = len(FIDs)
        plt.figure(figsize=(10, 5))
        plt.title("FID on Validation Set")
        plt.plot(epochs, FIDs)
        plt.xlabel("epochs")
        plt.ylabel("FID")
        #plt.xticks([i * 49 for i in range(1, N+1)])
        plt.savefig(path)
        plt.close()

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(1, opt.num_epochs + 1):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_imgs = data[0].to(device)
            batch_size = real_imgs.size(0)
            label = torch.full((batch_size, ), real_label, device=device)

            # Forward pass real batch through D
            output = netD(real_imgs).view(-1)

            # Calculate loss on all-real batch
            errD_real = criterion(output, label)

            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(batch_size, nz, 1, 1, device=device)

            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)

            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)

            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)

            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Save Losses for plotting
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Output training stats
            if i % opt.print_every == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f] [D(x): %.4f] [D(G(z)): %.4f / %.4f]"
                    % (epoch, opt.num_epochs, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))

            # epoch starts at 1, so count only the batches from earlier epochs
            batches_done = (epoch - 1) * len(dataloader) + i

            if (batches_done % opt.sample_interval == 0) or (
                    epoch == opt.num_epochs and i == len(dataloader) - 1):
                # Put G in eval mode
                netG.eval()

                with torch.no_grad():
                    sample_images(opt.num_sample_images, batches_done)
                vutils.save_image(fake.data[:25],
                                  "{}/{}.png".format(output_train_images_path,
                                                     batches_done),
                                  nrow=5,
                                  padding=2,
                                  normalize=True)

                # Put G back in train mode
                netG.train()

        # Save model checkpoint
        if (epoch != opt.num_epochs and epoch % opt.checkpoint_epochs == 0):
            print("Checkpoint at epoch {}".format(epoch))
            print("Saving generator model...")
            torch.save(
                netG.state_dict(),
                os.path.join(output_model_path,
                             "model_checkpoint_{}.pt".format(epoch)))
            print("Saving G & D loss plot...")
            save_loss_plot(
                os.path.join(opt.output_path, opt.version,
                             "loss_plot_{}.png".format(epoch)))

            print("Validating model...")
            netG.eval()
            with torch.no_grad():
                fid = validate(keep_images=False)
            print("Validation FID: {}".format(fid))
            with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
                      "a") as f:
                f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
            FIDs.append(fid)
            val_epochs.append(epoch)
            print("Saving FID plot...")
            save_fid_plot(
                FIDs, val_epochs,
                os.path.join(opt.output_path, opt.version,
                             "fid_plot_{}.png".format(epoch)))
            netG.train()

    print("Saving final generator model...")
    torch.save(netG.state_dict(), os.path.join(output_model_path, "model.pt"))
    print("Done!")

    print("Saving final G & D loss plot...")
    save_loss_plot(os.path.join(opt.output_path, opt.version, "loss_plot.png"))
    print("Done!")

    print("Validating final model...")
    netG.eval()
    with torch.no_grad():
        fid = validate()
    print("Final Validation FID: {}".format(fid))
    with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
              "a") as f:
        f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
    FIDs.append(fid)
    val_epochs.append(epoch)
    print("Saving final FID plot...")
    save_fid_plot(FIDs, val_epochs,
                  os.path.join(opt.output_path, opt.version, "fid_plot"))