def _main():
    print_gpu_details()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_root = args.train_path

    image_size = 256
    cropped_image_size = 256
    print("set image folder")
    train_set = dset.ImageFolder(root=train_root,
                                 transform=transforms.Compose([
                                     transforms.Resize(image_size),
                                     transforms.CenterCrop(cropped_image_size),
                                     transforms.ToTensor()
                                 ]))

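    # two normalizations: ImageNet mean/std for the pre-trained classifier,
    # and mean=std=0.5 (scaling to [-1, 1]) for the discriminator inputs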
    normalizer_clf = transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    normalizer_discriminator = transforms.Compose([
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    print('set data loader')
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

    # Network creation
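    # the classifier is a pre-trained network used as a frozen feature extractor (eval mode, never optimized)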
    classifier = torch.load(args.classifier_path)
    classifier.eval()
    generator = Generator(gen_type=args.gen_type)
    discriminator = Discriminator(args.discriminator_norm, dis_type=args.gen_type)
    # init weights
    if args.generator_path is not None:
        generator.load_state_dict(torch.load(args.generator_path))
    else:
        generator.init_weights()
    if args.discriminator_path is not None:
        discriminator.load_state_dict(torch.load(args.discriminator_path))
    else:
        discriminator.init_weights()

    classifier.to(device)
    generator.to(device)
    discriminator.to(device)

    # losses + optimizers
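    # Wasserstein critic/generator losses, plus L1 criteria for the feature-reconstruction and diversity terms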
    criterion_discriminator, criterion_generator = get_wgan_losses_fn()
    criterion_features = nn.L1Loss()
    criterion_diversity_n = nn.L1Loss()
    criterion_diversity_d = nn.L1Loss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999))

    num_of_epochs = args.epochs

    starting_time = time.time()
    iterations = 0
    # creating dirs for keeping models checkpoint, temp created images, and loss summary
    outputs_dir = os.path.join('wgan-gp_models', args.model_name)
    os.makedirs(outputs_dir, exist_ok=True)
    temp_results_dir = os.path.join(outputs_dir, 'temp_results')
    os.makedirs(temp_results_dir, exist_ok=True)
    models_dir = os.path.join(outputs_dir, 'models_checkpoint')
    os.makedirs(models_dir, exist_ok=True)
    writer = tensorboardX.SummaryWriter(os.path.join(outputs_dir, 'summaries'))

    z = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for sampling
    z2 = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for diversity sampling
    fixed_features = None
    fixed_masks = None
    fixed_features_diversity = None
    first_iter = True
    print("Starting Training Loop...")
    for epoch in range(num_of_epochs):
        for data in train_loader:
            train_type = random.choices([1, 2], [args.train1_prob, 1 - args.train1_prob])[0]  # choose train type (random.choices returns a list)
            iterations += 1
            if iterations % 30 == 1:
                print('epoch:', epoch, ', iter', iterations, 'start, time =', time.time() - starting_time, 'seconds')
                starting_time = time.time()
            images, _ = data
            images = images.to(device)  # change to gpu tensor
            images_discriminator = normalizer_discriminator(images)
            images_clf = normalizer_clf(images)
            _, features = classifier(images_clf)
            if first_iter: # save batch of images to keep track of the model process
                first_iter = False
                fixed_features = [torch.clone(features[x]) for x in range(len(features))]
                fixed_masks = [torch.ones(features[x].shape, device=device) for x in range(len(features))]
                fixed_features_diversity = [torch.clone(features[x]) for x in range(len(features))]
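                # repeat the first 8 feature maps across the batch so every row of the diversity grid shows the same 8 source images generated with different noise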
                for i in range(len(features)):
                    for j in range(fixed_features_diversity[i].shape[0]):
                        fixed_features_diversity[i][j] = fixed_features_diversity[i][j % 8]
                grid = vutils.make_grid(images_discriminator, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images.jpg'))
                orig_images_diversity = torch.clone(images_discriminator)
                for i in range(orig_images_diversity.shape[0]):
                    orig_images_diversity[i] = orig_images_diversity[i % 8]
                grid = vutils.make_grid(orig_images_diversity, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images_diversity.jpg'))
            # Select a features layer to train on
            features_to_train = random.randint(1, len(features) - 2) if args.fixed_layer is None else args.fixed_layer
            # Set masks
            masks = [features[i].clone() for i in range(len(features))]
            if train_type == 1:
                setMasksPart1(masks, device, features_to_train)
            else:
                setMasksPart2(masks, device, features_to_train)
            discriminator_loss_dict = train_discriminator(generator, discriminator, criterion_discriminator, discriminator_optimizer, images_discriminator, features, masks)
            for k, v in discriminator_loss_dict.items():
                writer.add_scalar('D/%s' % k, v.data.cpu().numpy(), global_step=iterations)
                if iterations % 30 == 1:
                    print('{}: {:.6f}'.format(k, v))
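            # WGAN n-critic schedule: the generator is updated only once every args.discriminator_steps critic iterations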
            if iterations % args.discriminator_steps == 1:
                generator_loss_dict = train_generator(generator, discriminator, criterion_generator, generator_optimizer, images.shape[0], features,
                                                      criterion_features, features_to_train, classifier, normalizer_clf, criterion_diversity_n,
                                                      criterion_diversity_d, masks, train_type)

                for k, v in generator_loss_dict.items():
                    writer.add_scalar('G/%s' % k, v.data.cpu().numpy(), global_step=iterations // args.discriminator_steps + 1)
                    if iterations % 30 == 1:
                        print('{}: {:.6f}'.format(k, v))

            # Save generator and discriminator weights every 1000 iterations
            if iterations % 1000 == 1:
                torch.save(generator.state_dict(), os.path.join(models_dir, args.model_name + 'G'))
                torch.save(discriminator.state_dict(), os.path.join(models_dir, args.model_name + 'D'))
            # Save temp results
            if args.keep_temp_results:
                if (iterations < 10000 and iterations % 1000 == 1) or iterations % 2000 == 1:
                    # regular sampling (batch of different images)
                    first_features = True
                    fake_images = None
                    fake_images_diversity = None
                    for i in range(1, 5):
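                        # isolate each of feature layers 1..4 in the mask and sample with only that layer active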
                        one_layer_mask = isolate_layer(fixed_masks, i, device)
                        if first_features:
                            first_features = False
                            fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images_diversity = sample(generator, z2, fixed_features_diversity, one_layer_mask)
                        else:
                            tmp_fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images = torch.vstack((fake_images, tmp_fake_images))
                            tmp_fake_images = sample(generator, z2, fixed_features_diversity, one_layer_mask)
                            fake_images_diversity = torch.vstack((fake_images_diversity, tmp_fake_images))
                    grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'res_iter_{}.jpg'.format(iterations // 1000)))
                    # diversity sampling (8 different images each with few different noises)
                    grid = vutils.make_grid(fake_images_diversity, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'div_iter_{}.jpg'.format(iterations // 1000)))

                if iterations % 20000 == 1:
                    torch.save(generator.state_dict(), os.path.join(models_dir, args.model_name + 'G_' + str(iterations // 20000)))
                    torch.save(discriminator.state_dict(), os.path.join(models_dir, args.model_name + 'D_' + str(iterations // 20000)))
Example #2
if not os.path.exists(OUTPUT_PATH):
    os.makedirs(OUTPUT_PATH)

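# resize, center-crop and rescale the images to [-1, 1]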
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(IMAGE_SIZE),
    torchvision.transforms.CenterCrop(IMAGE_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
fixed_noise = torch.randn(BATCH_SIZE, Z_SIZE, 1, 1, device=device)
netG = Generator(IMAGE_SIZE, IMAGE_CHANNELS, Z_SIZE).to(device)
netD = Discriminator(IMAGE_SIZE, IMAGE_CHANNELS).to(device)
netG.init_weights()
netD.init_weights()
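# standard GAN binary cross-entropy objective over real/fake predictions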
criterion = nn.BCELoss()

optimizerD = optim.Adam(netD.parameters(),
                        lr=LEARNING_RATE,
                        betas=(BETA1, 0.999))
optimizerG = optim.Adam(netG.parameters(),
                        lr=LEARNING_RATE,
                        betas=(BETA1, 0.999))

print('Loading dataset...')
images_dataset = ImageDataset(IMAGES_PATH, transforms, REAL_LABEL, device,
                              '*.*')
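# 80/20 train/validation split; training uses the 'train' portion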
real_dataset = SplitDataset(images_dataset, {'train': 0.8, 'validation': 0.2})
real_dataset.select('train')
real_data_loader = DataLoader(real_dataset,