Example #1
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    """
    variable definition
    """

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    adversarial_loss = nn.BCELoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
              headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
              headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        # update the learning rate scheduler
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_f_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_f_loss.backward()

              discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_r_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_r_loss.backward()

              discriminator_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))

                if FLAGS.forward_gan:
                  print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                  print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))

                if FLAGS.reverse_gan:
                  print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                  print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = []

                row.append(epoch)
                row.append(iteration)
                row.append(reconstruction_error.data.storage().tolist()[0])
                row.append(kl_divergence_error.data.storage().tolist()[0])
                row.append(reverse_cycle_loss.data.storage().tolist()[0])

                if FLAGS.forward_gan:
                  row.append(gen_f_loss.data.storage().tolist()[0])
                  row.append(dis_f_loss.data.storage().tolist()[0])

                if FLAGS.reverse_gan:
                  row.append(gen_r_loss.data.storage().tolist()[0])
                  row.append(dis_r_loss.data.storage().tolist()[0])

                row = [str(x) for x in row]
                log.write('\t'.join(row) + '\n')

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.forward_gan:
              writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.reverse_gan:
              writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """

            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
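
# Example #1 calls a `reparameterize` helper that the snippet does not define.
# Below is a minimal sketch of the standard VAE reparameterization trick it
# presumably implements (an assumption; the repo's exact version may differ):
import torch

def reparameterize(training, mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I) during training;
    # fall back to the mean for deterministic evaluation
    if training:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    return mu
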
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_imgs', type=str, help='dataset path')
    parser.add_argument('--mask_imgs', type=str, help='dataset path')
    parser.add_argument('--log_dir',
                        type=str,
                        default='log',
                        help='Name of the log folder')
    parser.add_argument('--save_models',
                        type=bool,
                        default=True,
                        help='Set True if you want to save trained models')
    parser.add_argument('--pre_trained_model_path',
                        type=str,
                        default=None,
                        help='Pre-trained model path')
    parser.add_argument('--pre_trained_model_epoch',
                        type=str,
                        default=None,
                        help='Pre-trained model epoch e.g 200')
    parser.add_argument('--train_imgs_path',
                        type=str,
                        default='C:/Users/motur/coco/images/train2017',
                        help='Path to training images')
    parser.add_argument(
        '--train_annotation_path',
        type=str,
        default='C:/Users/motur/coco/annotations/instances_train2017.json',
        help='Path to annotation file, .json file')
    parser.add_argument('--category_names',
                        type=str,
                        default='giraffe,elephant,zebra,sheep,cow,bear',
                        help='List of categories in MS-COCO dataset')
    parser.add_argument('--num_test_img',
                        type=int,
                        default=16,
                        help='Number of images saved during training')
    parser.add_argument('--img_size',
                        type=int,
                        default=256,
                        help='Generated image size')
    parser.add_argument(
        '--local_patch_size',
        type=int,
        default=256,
        help='Image size of instance images after interpolation')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Mini-batch size')
    parser.add_argument('--train_epoch',
                        type=int,
                        default=20,
                        help='Maximum training epoch')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='Initial learning rate')
    parser.add_argument('--optim_step_size',
                        type=int,
                        default=80,
                        help='Learning rate decay step size')
    parser.add_argument('--optim_gamma',
                        type=float,
                        default=0.5,
                        help='Learning rate decay ratio')
    parser.add_argument(
        '--critic_iter',
        type=int,
        default=5,
        help='Number of discriminator update against each generator update')
    parser.add_argument('--noise_size',
                        type=int,
                        default=128,
                        help='Noise vector size')
    parser.add_argument('--lambda_FM',
                        type=float,
                        default=1,
                        help='Trade-off param for feature matching loss')
    parser.add_argument('--lambda_recon',
                        type=float,
                        default=0.00001,
                        help='Trade-off param for reconstruction loss')
    parser.add_argument('--num_res_blocks',
                        type=int,
                        default=5,
                        help='Number of residual block in generator network')
    parser.add_argument(
        '--trade_off_G',
        type=float,
        default=0.1,
        help=
        'Trade-off parameter which controls gradient flow to generator from D_local and D_glob'
    )

    opt = parser.parse_args()
    print(opt)

    #Create log folder
    root = 'result_fg/' + opt.category_names + '/'
    model = 'coco_model_'
    result_folder_name = 'images_' + opt.log_dir
    model_folder_name = 'models_' + opt.log_dir
    if not os.path.isdir(root):
        os.makedirs(root)
    if not os.path.isdir(root + result_folder_name):
        os.makedirs(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.makedirs(root + model_folder_name)

    #Save the script
    copyfile(os.path.basename(__file__),
             root + result_folder_name + '/' + os.path.basename(__file__))

    #Define transformation for dataset images - e.g scaling
    transform = transforms.Compose([
        transforms.Resize((opt.img_size, opt.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    #Load dataset
    category_names = opt.category_names.split(',')
    allmasks = sorted(
        glob.glob(os.path.join(opt.mask_imgs, '**', '*.png'), recursive=True))
    print('Number of masks: %d' % len(allmasks))
    dataset = chairs(imfile=opt.train_imgs,
                     mfiles=allmasks,
                     category_names=category_names,
                     transform=transform,
                     final_img_size=opt.img_size)

    #Discard images contain very small instances
    # dataset.discard_small(min_area=0.03, max_area=1)

    #Define data loader
    train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

    #For evaluation define fixed masks and noises
    data_iter = iter(train_loader)
    sample_batched = next(data_iter)
    x_fixed = sample_batched['image'][0:opt.num_test_img]
    x_fixed = Variable(x_fixed.cuda())
    y_fixed = sample_batched['single_fg_mask'][0:opt.num_test_img]
    y_fixed = Variable(y_fixed.cuda())
    z_fixed = torch.randn((opt.num_test_img, opt.noise_size))
    z_fixed = Variable(z_fixed.cuda())

    #Define networks
    G_fg = Generator_FG(z_dim=opt.noise_size,
                        label_channel=len(category_names),
                        num_res_blocks=opt.num_res_blocks)
    D_glob = Discriminator(channels=3 + len(category_names))
    D_instance = Discriminator(channels=3 + len(category_names),
                               input_size=opt.local_patch_size)
    G_fg.cuda()
    D_glob.cuda()
    D_instance.cuda()

    #Load parameters from pre-trained models
    if opt.pre_trained_model_path is not None and opt.pre_trained_model_epoch is not None:
        try:
            G_fg.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'G_fg_epoch_' +
                           opt.pre_trained_model_epoch))
            D_glob.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_glob_epoch_' +
                           opt.pre_trained_model_epoch))
            D_instance.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_local_epoch_' +
                           opt.pre_trained_model_epoch))
            print('Parameters are loaded!')
        except Exception:
            print('Error: Pre-trained parameters are not loaded!')

    #Define interpolation operation
    up_instance = nn.Upsample(size=(opt.local_patch_size,
                                    opt.local_patch_size),
                              mode='bilinear')

    #Define pooling operation for the case that image size and local patch size are mismatched
    pooling_instance = nn.Sequential()
    if opt.local_patch_size != opt.img_size:
        pooling_instance.add_module(
            '0', nn.AvgPool2d(int(opt.img_size / opt.local_patch_size)))

    #Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()

    #Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.cuda()

    #Define optimizer
    G_local_optimizer = optim.Adam(G_fg.parameters(),
                                   lr=opt.lr,
                                   betas=(0.0, 0.9))
    D_local_optimizer = optim.Adam(
        list(filter(lambda p: p.requires_grad, D_glob.parameters())) +
        list(filter(lambda p: p.requires_grad, D_instance.parameters())),
        lr=opt.lr,
        betas=(0.0, 0.9))
    #Define learning rate scheduler
    scheduler_G = lr_scheduler.StepLR(G_local_optimizer,
                                      step_size=opt.optim_step_size,
                                      gamma=opt.optim_gamma)
    scheduler_D = lr_scheduler.StepLR(D_local_optimizer,
                                      step_size=opt.optim_step_size,
                                      gamma=opt.optim_gamma)

    #----------------------------TRAIN-----------------------------------------
    print('training start!')
    start_time = time.time()

    for epoch in range(opt.train_epoch):
        epoch_start_time = time.time()

        scheduler_G.step()
        scheduler_D.step()

        D_local_losses = []
        G_local_losses = []

        y_real_ = torch.ones(opt.batch_size)
        y_fake_ = torch.zeros(opt.batch_size)
        y_real_, y_fake_ = Variable(y_real_.cuda()), Variable(y_fake_.cuda())

        data_iter = iter(train_loader)
        num_iter = 0
        while num_iter < len(train_loader):

            j = 0
            while j < opt.critic_iter and num_iter < len(train_loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1
                x_ = sample_batched['image']
                y_ = sample_batched['single_fg_mask']
                fg_mask = sample_batched['seg_mask']

                y_instances = sample_batched['mask_instance']
                bbox = sample_batched['bbox']

                mini_batch = x_.size()[0]
                if mini_batch != opt.batch_size:
                    break

                #Update discriminators - D
                #Real examples
                D_glob.zero_grad()
                D_instance.zero_grad()

                x_, y_ = Variable(x_.cuda()), Variable(y_.cuda())
                fg_mask = Variable(fg_mask.cuda())
                y_reduced = torch.sum(y_,
                                      1).clamp(0,
                                               1).view(y_.size(0), 1,
                                                       opt.img_size,
                                                       opt.img_size)

                x_d = torch.cat([x_, fg_mask], 1)

                x_instances = torch.zeros(
                    (opt.batch_size, 3, opt.local_patch_size,
                     opt.local_patch_size))
                x_instances = Variable(x_instances.cuda())
                y_instances = Variable(y_instances.cuda())
                y_instances = pooling_instance(y_instances)
                G_instances = torch.zeros(
                    (opt.batch_size, 3, opt.local_patch_size,
                     opt.local_patch_size))
                G_instances = Variable(G_instances.cuda())

                #Obtain instances
                for t in range(x_d.size()[0]):
                    x_instance = x_[t, 0:3, bbox[0][t]:bbox[1][t],
                                    bbox[2][t]:bbox[3][t]]
                    x_instance = x_instance.contiguous().view(
                        1,
                        x_instance.size()[0],
                        x_instance.size()[1],
                        x_instance.size()[2])
                    x_instances[t] = up_instance(x_instance)

                D_result_instance = D_instance(
                    torch.cat([x_instances, y_instances], 1)).squeeze()
                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_) + BCE_loss(
                    D_result_instance, y_real_)
                D_real_loss.backward()

                #Fake examples
                z_ = torch.randn((mini_batch, opt.noise_size))
                z_ = Variable(z_.cuda())

                #Generate fake images
                G_fg_result = G_fg(z_, y_, torch.mul(x_, (1 - y_reduced)))
                G_result_d = torch.cat([G_fg_result, fg_mask], 1)

                #Obtain fake instances
                for t in range(x_d.size()[0]):
                    G_instance = G_result_d[t, 0:3, bbox[0][t]:bbox[1][t],
                                            bbox[2][t]:bbox[3][t]]
                    G_instance = G_instance.contiguous().view(
                        1,
                        G_instance.size()[0],
                        G_instance.size()[1],
                        G_instance.size()[2])
                    G_instances[t] = up_instance(G_instance)

                D_result_instance = D_instance(
                    torch.cat([G_instances, y_instances],
                              1).detach()).squeeze()
                D_result = D_glob(G_result_d.detach()).squeeze()
                D_fake_loss = BCE_loss(D_result, y_fake_) + BCE_loss(
                    D_result_instance, y_fake_)
                D_fake_loss.backward()
                D_local_optimizer.step()

                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.data)

            if mini_batch != opt.batch_size:
                break

            #Update generator G
            G_fg.zero_grad()
            D_result = D_glob(G_result_d).squeeze()
            D_result_instance = D_instance(
                torch.cat([G_instances, y_instances], 1)).squeeze()
            G_train_loss = (1 - opt.trade_off_G) * BCE_loss(
                D_result, y_real_) + opt.trade_off_G * BCE_loss(
                    D_result_instance, y_real_)

            #Feature matching loss between generated image and corresponding ground truth
            FM_loss = criterionVGG(G_fg_result, x_)

            #Reconstruction loss
            Recon_loss = mse_loss(torch.mul(x_, (1 - y_reduced)),
                                  torch.mul(G_fg_result, (1 - y_reduced)))

            total_loss = G_train_loss + opt.lambda_FM * FM_loss + opt.lambda_recon * Recon_loss
            total_loss.backward()
            G_local_optimizer.step()
            G_local_losses.append(G_train_loss.data)

            print('loss_d: %.3f, loss_g: %.3f' %
                  (D_train_loss.data, G_train_loss.data))
            if (num_iter % 100) == 0:
                print('%d - %d complete!' % ((epoch + 1), num_iter))
                print(result_folder_name)

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' %
              ((epoch + 1), opt.train_epoch, per_epoch_ptime,
               torch.mean(torch.FloatTensor(D_local_losses)),
               torch.mean(torch.FloatTensor(G_local_losses))))

        #Save images
        G_fg.eval()

        if epoch == 0:
            show_result_rgb((epoch + 1),
                            x_fixed,
                            save=True,
                            path=root + result_folder_name + '/' + model +
                            str(epoch + 1) + '_gt.png')
            for t in range(y_fixed.size()[1]):
                show_result_rgb((epoch + 1),
                                y_fixed[:, t:t + 1, :, :],
                                save=True,
                                path=root + result_folder_name + '/' + model +
                                str(epoch + 1) + '_' + str(t) + '_masked.png')

        show_result_rgb(
            (epoch + 1),
            G_fg(
                z_fixed, y_fixed,
                torch.mul(x_fixed, (1 - torch.sum(y_fixed, 1).view(
                    y_fixed.size(0), 1, opt.img_size, opt.img_size)))),
            save=True,
            path=root + result_folder_name + '/' + model + str(epoch + 1) +
            '_fg.png')
        G_fg.train()

        #Save model params
        if opt.save_models and (epoch > 11 and epoch % 10 == 0):
            torch.save(
                G_fg.state_dict(), root + model_folder_name + '/' + model +
                'G_fg_epoch_' + str(epoch) + '.pth')
            torch.save(
                D_glob.state_dict(), root + model_folder_name + '/' + model +
                'D_glob_epoch_' + str(epoch) + '.pth')
            torch.save(
                D_instance.state_dict(), root + model_folder_name + '/' +
                model + 'D_local_epoch_' + str(epoch) + '.pth')

    torch.save(
        G_fg.state_dict(), root + model_folder_name + '/' + model +
        'G_fg_epoch_' + str(epoch) + '.pth')
    torch.save(
        D_glob.state_dict(), root + model_folder_name + '/' + model +
        'D_glob_epoch_' + str(epoch) + '.pth')
    torch.save(
        D_instance.state_dict(), root + model_folder_name + '/' + model +
        'D_local_epoch_' + str(epoch) + '.pth')
    end_time = time.time()
    total_ptime = end_time - start_time
    print("Training finish!... save training results")
    print('Training time: ' + str(total_ptime))
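
# Example #2 feeds its local discriminator instance patches that are cropped
# with per-image bounding boxes and rescaled to a fixed size. The same
# crop-and-resize step as a self-contained helper (a sketch; the bbox layout
# [top, bottom, left, right] is inferred from the slicing above):
import torch
import torch.nn as nn

def crop_and_resize_instances(images, bbox, patch_size):
    # crop each instance with its bounding box and bilinearly rescale it
    # to patch_size x patch_size for the local discriminator
    up = nn.Upsample(size=(patch_size, patch_size), mode='bilinear')
    patches = torch.zeros(images.size(0), 3, patch_size, patch_size,
                          device=images.device)
    for t in range(images.size(0)):
        patch = images[t, 0:3, bbox[0][t]:bbox[1][t], bbox[2][t]:bbox[3][t]]
        patches[t] = up(patch.unsqueeze(0))[0]
    return patches
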
Example #3
            transforms.CenterCrop(opt.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])),
                              batch_size=opt.batch_size,
                              shuffle=True)

generator = Generator(opt.latent_dim, opt.channels)
discriminator = Discriminator(opt.n_classes, opt.channels)

adversarial_loss = nn.BCELoss()
classifier_loss = nn.CrossEntropyLoss()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    classifier_loss.cuda()
    torch.cuda.set_device(opt.gpu_ids)

print(generator, discriminator)

optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
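
# Example #3 stops after the setup; the BCE adversarial loss plus the
# cross-entropy classifier loss it defines are the AC-GAN combination. One
# hypothetical training iteration built on that setup (the generator and
# discriminator forward signatures below are assumptions, not the repo's API):
import torch

def acgan_step(imgs, labels, generator, discriminator, optimizer_G, optimizer_D,
               adversarial_loss, classifier_loss, latent_dim, n_classes):
    batch_size = imgs.size(0)
    valid = torch.ones(batch_size, 1, device=imgs.device)
    fake = torch.zeros(batch_size, 1, device=imgs.device)

    # generator update: generated samples should look valid to D
    optimizer_G.zero_grad()
    z = torch.randn(batch_size, latent_dim, device=imgs.device)
    gen_labels = torch.randint(0, n_classes, (batch_size,), device=imgs.device)
    gen_imgs = generator(z, gen_labels)
    validity, pred_label = discriminator(gen_imgs)
    g_loss = adversarial_loss(validity, valid) + classifier_loss(pred_label, gen_labels)
    g_loss.backward()
    optimizer_G.step()

    # discriminator update: real/fake separation plus class prediction on reals
    optimizer_D.zero_grad()
    real_validity, real_pred = discriminator(imgs)
    fake_validity, _ = discriminator(gen_imgs.detach())
    d_loss = 0.5 * (adversarial_loss(real_validity, valid)
                    + adversarial_loss(fake_validity, fake)) \
             + classifier_loss(real_pred, labels)
    d_loss.backward()
    optimizer_D.step()
    return g_loss.item(), d_loss.item()
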
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir',
                        type=str,
                        default='log',
                        help='Name of the log folder')
    parser.add_argument('--save_models',
                        type=bool,
                        default=True,
                        help='Set True if you want to save trained models')
    parser.add_argument('--pre_trained_model_path',
                        type=str,
                        default=None,
                        help='Pre-trained model path')
    parser.add_argument('--pre_trained_model_epoch',
                        type=str,
                        default=None,
                        help='Pre-trained model epoch e.g 200')
    parser.add_argument('--train_imgs_path',
                        type=str,
                        default='/mnt/sdb/data/COCO/train2017',
                        help='Path to training images')
    parser.add_argument(
        '--train_annotation_path',
        type=str,
        default='/mnt/sdb/data/COCO/annotations/instances_train2017.json',
        help='Path to annotation file, .json file')
    parser.add_argument('--category_names',
                        type=str,
                        default='giraffe,elephant,zebra,sheep,cow,bear',
                        help='List of categories in MS-COCO dataset')
    parser.add_argument('--num_test_img',
                        type=int,
                        default=4,
                        help='Number of images saved during training')
    parser.add_argument('--img_size',
                        type=int,
                        default=256,
                        help='Generated image size')
    parser.add_argument(
        '--local_patch_size',
        type=int,
        default=256,
        help='Image size of instance images after interpolation')
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help='Mini-batch size')
    parser.add_argument('--train_epoch',
                        type=int,
                        default=400,
                        help='Maximum training epoch')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='Initial learning rate')
    parser.add_argument('--optim_step_size',
                        type=int,
                        default=80,
                        help='Learning rate decay step size')
    parser.add_argument('--optim_gamma',
                        type=float,
                        default=0.5,
                        help='Learning rate decay ratio')
    parser.add_argument(
        '--critic_iter',
        type=int,
        default=5,
        help='Number of discriminator update against each generator update')
    parser.add_argument('--noise_size',
                        type=int,
                        default=256,
                        help='Noise vector size')
    parser.add_argument('--lambda_FM',
                        type=float,
                        default=1,
                        help='Trade-off param for feature matching loss')
    parser.add_argument('--lambda_branch',
                        type=float,
                        default=100,
                        help='Trade-off param for reconstruction loss')
    parser.add_argument(
        '--num_res_blocks',
        type=int,
        default=2,
        help='Number of residual block in generator shared part')
    parser.add_argument('--num_res_blocks_fg',
                        type=int,
                        default=2,
                        help='Number of residual block in non-bg branch')
    parser.add_argument('--num_res_blocks_bg',
                        type=int,
                        default=0,
                        help='Number of residual block in generator bg branch')

    opt = parser.parse_args()
    print(opt)

    #Create log folder
    root = 'result_bg/'
    model = 'coco_model_'
    result_folder_name = 'images_' + opt.log_dir
    model_folder_name = 'models_' + opt.log_dir
    if not os.path.isdir(root):
        os.mkdir(root)
    if not os.path.isdir(root + result_folder_name):
        os.mkdir(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.mkdir(root + model_folder_name)

    #Save the script
    copyfile(os.path.basename(__file__),
             root + result_folder_name + '/' + os.path.basename(__file__))

    #Define transformation for dataset images - e.g scaling
    transform = transforms.Compose([
        transforms.Resize((opt.img_size, opt.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    #Load dataset
    category_names = opt.category_names.split(',')
    dataset = CocoData(root=opt.train_imgs_path,
                       annFile=opt.train_annotation_path,
                       category_names=category_names,
                       transform=transform,
                       final_img_size=opt.img_size)

    #Discard images contain very small instances
    dataset.discard_small(min_area=0.0, max_area=1)
    #dataset.discard_bad_examples('bad_examples_list.txt')

    #Define data loader
    train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

    #For evaluation define fixed masks and noises
    data_iter = iter(train_loader)
    sample_batched = next(data_iter)
    y_fixed = sample_batched['seg_mask'][0:opt.num_test_img]
    y_fixed = Variable(y_fixed.cuda())
    z_fixed = torch.randn((opt.num_test_img, opt.noise_size))
    z_fixed = Variable(z_fixed.cuda())

    #Define networks
    G_bg = Generator_BG(z_dim=opt.noise_size,
                        label_channel=len(category_names),
                        num_res_blocks=opt.num_res_blocks,
                        num_res_blocks_fg=opt.num_res_blocks_fg,
                        num_res_blocks_bg=opt.num_res_blocks_bg)
    D_glob = Discriminator(channels=3 + len(category_names),
                           input_size=opt.img_size)
    G_bg.cuda()
    D_glob.cuda()

    #Load parameters from pre-trained models
    if opt.pre_trained_model_path is not None and opt.pre_trained_model_epoch is not None:
        try:
            G_bg.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'G_bg_epoch_' +
                           opt.pre_trained_model_epoch))
            D_glob.load_state_dict(
                torch.load(opt.pre_trained_model_path + 'D_glob_epoch_' +
                           opt.pre_trained_model_epoch))
            print('Parameters are loaded!')
        except Exception:
            print('Error: Pre-trained parameters are not loaded!')

    #Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()

    #Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.cuda()

    #Define optimizer
    G_local_optimizer = optim.Adam(G_bg.parameters(),
                                   lr=opt.lr,
                                   betas=(0.0, 0.9))
    D_local_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          D_glob.parameters()),
                                   lr=opt.lr,
                                   betas=(0.0, 0.9))

    #Define learning rate scheduler
    scheduler_G = lr_scheduler.StepLR(G_local_optimizer,
                                      step_size=opt.optim_step_size,
                                      gamma=opt.optim_gamma)
    scheduler_D = lr_scheduler.StepLR(D_local_optimizer,
                                      step_size=opt.optim_step_size,
                                      gamma=opt.optim_gamma)

    #----------------------------TRAIN---------------------------------------
    print('training start!')
    start_time = time.time()

    for epoch in range(opt.train_epoch):
        scheduler_G.step()
        scheduler_D.step()

        D_local_losses = []
        G_local_losses = []

        y_real_ = torch.ones(opt.batch_size)
        y_fake_ = torch.zeros(opt.batch_size)
        y_real_, y_fake_ = Variable(y_real_.cuda()), Variable(y_fake_.cuda())
        epoch_start_time = time.time()

        data_iter = iter(train_loader)
        num_iter = 0
        while num_iter < len(train_loader):
            j = 0
            while j < opt.critic_iter and num_iter < len(train_loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1
                x_ = sample_batched['image']
                y_ = sample_batched['seg_mask']
                y_reduced = torch.sum(y_, 1).view(y_.size(0), 1, y_.size(2),
                                                  y_.size(3))
                y_reduced = torch.clamp(y_reduced, 0, 1)
                y_reduced = Variable(y_reduced.cuda())

                #Update discriminators - D
                #Real examples
                D_glob.zero_grad()
                mini_batch = x_.size()[0]

                if mini_batch != opt.batch_size:
                    y_real_ = torch.ones(mini_batch)
                    y_fake_ = torch.zeros(mini_batch)
                    y_real_, y_fake_ = Variable(y_real_.cuda()), Variable(
                        y_fake_.cuda())

                x_, y_ = Variable(x_.cuda()), Variable(y_.cuda())
                x_d = torch.cat([x_, y_], 1)

                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_)
                D_real_loss.backward()

                #Fake examples
                z_ = torch.randn((mini_batch, opt.noise_size))
                z_ = Variable(z_.cuda())

                #Generate fake images
                G_result, G_result_bg = G_bg(z_, y_)
                G_result_d = torch.cat([G_result, y_], 1)
                D_result = D_glob(G_result_d.detach()).squeeze()

                D_fake_loss = BCE_loss(D_result, y_fake_)
                D_fake_loss.backward()
                D_local_optimizer.step()
                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.item())
            #Update generator G
            G_bg.zero_grad()
            D_result = D_glob(G_result_d).squeeze()

            G_train_loss = BCE_loss(D_result, y_real_)

            #Feature matching loss between generated image and corresponding ground truth
            FM_loss = criterionVGG(G_result, x_)

            #Branch-similar loss
            branch_sim_loss = mse_loss(torch.mul(G_result, (1 - y_reduced)),
                                       torch.mul(G_result_bg, (1 - y_reduced)))

            total_loss = G_train_loss + opt.lambda_FM * FM_loss + opt.lambda_branch * branch_sim_loss
            total_loss.backward()
            G_local_optimizer.step()
            G_local_losses.append(G_train_loss.item())

            print('loss_d: %.3f, loss_g: %.3f' %
                  (D_train_loss.item(), G_train_loss.item()))
            if (num_iter % 100) == 0:
                print('%d - %d complete!' % ((epoch + 1), num_iter))
                print(result_folder_name)
        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' %
              ((epoch + 1), opt.train_epoch, per_epoch_ptime,
               torch.mean(torch.FloatTensor(D_local_losses)),
               torch.mean(torch.FloatTensor(G_local_losses))))

        #Save images
        G_bg.eval()
        G_result, G_result_bg = G_bg(z_fixed, y_fixed)
        G_bg.train()

        if epoch % 10 == 0:
            for t in range(y_fixed.size()[1]):
                show_result((epoch + 1),
                            y_fixed[:, t:t + 1, :, :],
                            save=True,
                            path=root + result_folder_name + '/' + model +
                            str(epoch + 1) + '_' + str(t) + '_masked.png')

        show_result((epoch + 1),
                    G_result,
                    save=True,
                    path=root + result_folder_name + '/' + model +
                    str(epoch + 1) + '.png')
        show_result((epoch + 1),
                    G_result_bg,
                    save=True,
                    path=root + result_folder_name + '/' + model +
                    str(epoch + 1) + '_bg.png')

        #Save model params
        if opt.save_models and (epoch > 21 and epoch % 10 == 0):
            torch.save(
                G_bg.state_dict(), root + model_folder_name + '/' + model +
                'G_bg_epoch_' + str(epoch) + '.pth')
            torch.save(
                D_glob.state_dict(), root + model_folder_name + '/' + model +
                'D_glob_epoch_' + str(epoch) + '.pth')

    end_time = time.time()
    total_ptime = end_time - start_time
    print("Training finish!... save training results")
    print('Training time: ' + str(total_ptime))
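
# Examples #2 and #4 run opt.critic_iter discriminator updates per generator
# update by walking one shared data iterator with nested while loops. The same
# control flow in compact form (a sketch with the loss computations elided):
def train_one_epoch(train_loader, critic_iter, d_step, g_step):
    data_iter = iter(train_loader)
    num_iter = 0
    batch = None
    while num_iter < len(train_loader):
        # several discriminator updates, each on a fresh batch
        for _ in range(critic_iter):
            if num_iter >= len(train_loader):
                break
            batch = next(data_iter)
            num_iter += 1
            d_step(batch)
        # one generator update per block of discriminator updates
        g_step(batch)
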
Example #5
def train_gmm(opt, train_loader, model, board):
    model.cuda()
    discriminator = Discriminator()
    discriminator.cuda()

    model.train()
    discriminator.train()

    # criterion
    criterion = nn.BCELoss()
    criterionL1 = nn.L1Loss()
    criterionPSC = PSCLoss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    #count = 0

    #base_name = os.path.basename(opt.checkpoint)
    # save_dir = os.path.join(opt.result_dir, opt.datamode)
    # if not os.path.exists(save_dir):
    #     os.makedirs(save_dir)
    # warp_cloth_dir = os.path.join(save_dir, 'warp-cloth')
    # if not os.path.exists(warp_cloth_dir):
    #    os.makedirs(warp_cloth_dir)

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        c_names = inputs['c_name']
        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()
        blank = inputs['blank'].cuda()

        grid, theta = model(agnostic, c)
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        #if (count < 14222):
        #    save_images(warped_cloth, c_names, warp_cloth_dir)
        #    print(warped_cloth.size()[0])
        #    count+=warped_cloth.size()[0]

        visuals = [[im_h, shape, im_pose], [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        discriminator_train_step(opt.batch_size, discriminator, model,
                                 optimizer_d, criterion, im_c, agnostic, c)
        res, loss, lossL1, lossPSC, lossGAN = generator_train_step(
            opt.batch_size, discriminator, model, optimizer, criterion,
            criterionL1, criterionPSC, blank, im_c, agnostic, c)

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('lossL1', lossL1.item(), step + 1)
            board.add_scalar('lossPSC', lossPSC.item(), step + 1)
            board.add_scalar('lossGAN', lossGAN.item(), step + 1)
            board.add_scalar('loss', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, lossL1: %.4f' %
                  (step + 1, t, lossL1.item()),
                  flush=True)
            print('step: %8d, time: %.3f, lossPSC: %.4f' %
                  (step + 1, t, lossPSC.item()),
                  flush=True)
            print('step: %8d, time: %.3f, lossGAN: %.4f' %
                  (step + 1, t, lossGAN.item()),
                  flush=True)
            print('step: %8d, time: %.3f, loss: %.4f' %
                  (step + 1, t, loss.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
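
# Example #5 delegates the actual updates to discriminator_train_step and
# generator_train_step, which are not shown. A plausible sketch of the
# discriminator step, inferred from the call site (the warped cloth produced
# by the GMM plays the fake, the parse-cloth crop im_c the real; the repo's
# implementation may differ):
import torch
import torch.nn.functional as F

def discriminator_train_step(batch_size, discriminator, model, optimizer_d,
                             criterion, im_c, agnostic, c):
    optimizer_d.zero_grad()

    # warp the in-shop cloth with the current GMM and treat it as the fake
    grid, _ = model(agnostic, c)
    warped_cloth = F.grid_sample(c, grid, padding_mode='border').detach()

    real_labels = torch.ones(batch_size, device=im_c.device)
    fake_labels = torch.zeros(batch_size, device=im_c.device)

    d_loss = (criterion(discriminator(im_c).squeeze(), real_labels)
              + criterion(discriminator(warped_cloth).squeeze(), fake_labels))
    d_loss.backward()
    optimizer_d.step()
    return d_loss.item()
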
Example #6
class Trainer:
    def __init__(self, dataset_dir, generator_channels, discriminator_channels, nz, style_depth, lrs, betas, eps,
                 phase_iter, weights_halflife, batch_size, n_cpu, opt_level):
        self.nz = nz
        self.dataloader = Dataloader(dataset_dir, batch_size, phase_iter * 2, n_cpu)

        self.generator = Generator(generator_channels, nz, style_depth).cuda()
        self.generator_ema = Generator(generator_channels, nz, style_depth).cuda()
        self.generator_ema.load_state_dict(copy.deepcopy(self.generator.state_dict()))
        self.discriminator = Discriminator(discriminator_channels).cuda()

        self.tb = tensorboard.tf_recorder('StyleGAN')

        self.phase_iter = phase_iter
        self.lrs = lrs
        self.betas = betas
        self.weights_halflife = weights_halflife

        self.opt_level = opt_level

        self.ema = None

        torch.backends.cudnn.benchmark = True

    def generator_trainloop(self, batch_size, alpha):
        requires_grad(self.generator, True)
        requires_grad(self.discriminator, False)

        # mixing regularization
        if random.random() < 0.9:
            z = [torch.randn(batch_size, self.nz).cuda(),
                 torch.randn(batch_size, self.nz).cuda()]
        else:
            z = torch.randn(batch_size, self.nz).cuda()

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss = F.softplus(-d_fake).mean()

        self.optimizer_g.zero_grad()
        with amp.scale_loss(loss, self.optimizer_g) as scaled_loss:
            scaled_loss.backward()
        self.optimizer_g.step()
        for name, param in self.generator.named_parameters():
            if param.requires_grad:
                self.ema(name, param.data)

        return loss.item()

    def discriminator_trainloop(self, real, alpha):
        requires_grad(self.generator, False)
        requires_grad(self.discriminator, True)

        real.requires_grad = True
        self.optimizer_d.zero_grad()

        d_real = self.discriminator(real, alpha=alpha)
        loss_real = F.softplus(-d_real).mean()
        with amp.scale_loss(loss_real, self.optimizer_d) as scaled_loss_real:
            scaled_loss_real.backward(retain_graph=True)

        grad_real = grad(
            outputs=d_real.sum(), inputs=real, create_graph=True
        )[0]
        grad_penalty = (
                grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
        ).mean()
        grad_penalty = 10 / 2 * grad_penalty
        with amp.scale_loss(grad_penalty, self.optimizer_d) as scaled_grad_penalty:
            scaled_grad_penalty.backward()

        if random.random() < 0.9:
            z = [torch.randn(real.size(0), self.nz).cuda(),
                 torch.randn(real.size(0), self.nz).cuda()]
        else:
            z = torch.randn(real.size(0), self.nz).cuda()

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss_fake = F.softplus(d_fake).mean()
        with amp.scale_loss(loss_fake, self.optimizer_d) as scaled_loss_fake:
            scaled_loss_fake.backward()

        loss = scaled_loss_real + scaled_loss_fake + scaled_grad_penalty

        self.optimizer_d.step()

        return loss.item(), (d_real.mean().item(), d_fake.mean().item())

    def run(self, log_iter, checkpoint):
        global_iter = 0

        test_z = torch.randn(4, self.nz).cuda()

        self.ema = self.init_ema()
        if checkpoint:
            self.load_checkpoint(checkpoint)
        else:
            self.grow()

        while True:
            print('train {}X{} images...'.format(self.dataloader.img_size, self.dataloader.img_size))

            for iter, ((data, _), n_trained_samples) in enumerate(tqdm(self.dataloader), 1):
                real = data.cuda()
                alpha = min(1, n_trained_samples / self.phase_iter) if self.dataloader.img_size > 8 else 1

                loss_d, (real_score, fake_score) = self.discriminator_trainloop(real, alpha)
                loss_g = self.generator_trainloop(real.size(0), alpha)

                if global_iter % log_iter == 0:
                    self.save_ema()
                    self.log(loss_d, loss_g, real_score, fake_score, test_z, alpha)

                # save 3 times during training
                if iter % (len(self.dataloader) // 4 + 1) == 0:
                    self.save_ema()
                    self.save_checkpoint(n_trained_samples)

                global_iter += 1
                self.tb.iter(data.size(0))
            self.save_ema()
            self.save_checkpoint()
            self.grow()

    def save_ema(self):
        self.ema.set_weights(self.generator_ema)

    def init_ema(self):
        ema = EMA()
        for name, param in self.generator.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)

        return ema

    def log(self, loss_d, loss_g, real_score, fake_score, test_z, alpha):
        with torch.no_grad():
            fake = self.generator(test_z, alpha=alpha)
            fake = (fake + 1) * 0.5
            fake = torch.clamp(fake, min=0.0, max=1.0)

            self.generator.cpu()
            self.generator_ema.cuda()
            fake_ema = self.generator_ema(test_z, alpha=alpha)
            fake_ema = (fake_ema + 1) * 0.5
            fake_ema = torch.clamp(fake_ema, min=0.0, max=1.0)
            self.generator_ema.cpu()
            self.generator.cuda()

        self.tb.add_scalar('loss_d', loss_d)
        self.tb.add_scalar('loss_g', loss_g)
        self.tb.add_scalar('real_score', real_score)
        self.tb.add_scalar('fake_score', fake_score)
        self.tb.add_images('fake', fake)
        self.tb.add_images('fake_ema', fake_ema)

    def grow(self):
        self.discriminator.grow()
        self.generator.grow()
        self.generator_ema.grow()
        self.dataloader.grow()

        self.generator.cuda()
        self.discriminator.cuda()

        decay = 0.0
        if self.weights_halflife > 0:
            decay = 0.5 ** (float(self.dataloader.batch_size) / self.weights_halflife)

        self.ema.grow(self.generator, decay)

        self.tb.renew('{}x{}'.format(self.dataloader.img_size, self.dataloader.img_size))

        self.lr = self.lrs.get(str(self.dataloader.img_size), 0.001)
        self.style_lr = self.lr * 0.01

        self.optimizer_d = optim.Adam(params=self.discriminator.parameters(), lr=self.lr, betas=self.betas)
        self.optimizer_g = optim.Adam([
            {'params': self.generator.model.parameters(), 'lr': self.lr},
            {'params': self.generator.style_mapper.parameters(), 'lr': self.style_lr},
        ],
            betas=self.betas
        )

        [self.generator, self.discriminator], [self.optimizer_g, self.optimizer_d] = amp.initialize(
            [self.generator, self.discriminator],
            [self.optimizer_g, self.optimizer_d],
            opt_level=self.opt_level
        )

    def save_checkpoint(self, tick='last'):
        torch.save({
            'generator': self.generator.state_dict(),
            'generator_ema': self.generator_ema.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'generator_optimizer': self.optimizer_g.state_dict(),
            'discriminator_optimizer': self.optimizer_d.state_dict(),
            'img_size': self.dataloader.img_size,
            'tick': tick,
        }, 'checkpoints/{}x{}_{}.pth'.format(self.dataloader.img_size, self.dataloader.img_size, tick))

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        print('load {}x{} checkpoint'.format(checkpoint['img_size'], checkpoint['img_size']))
        while self.dataloader.img_size < checkpoint['img_size']:
            self.grow()

        self.generator.load_state_dict(checkpoint['generator'])
        self.generator_ema.load_state_dict(checkpoint['generator_ema'])
        self.discriminator.load_state_dict(checkpoint['discriminator'])
        self.optimizer_g.load_state_dict(checkpoint['generator_optimizer'])
        self.optimizer_d.load_state_dict(checkpoint['discriminator_optimizer'])

        if checkpoint['tick'] == 'last':
            self.grow()
        else:
            self.dataloader.set_checkpoint(checkpoint['tick'])
            self.tb.iter(checkpoint['tick'])
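
The trainer above depends on an EMA helper (register, grow, set_weights) that is not part of this snippet. A minimal sketch of a weight tracker with that interface, keeping a shadow copy of the generator parameters, could look like this (an assumption, not the original implementation):

class EMA:
    """Sketch of an exponential-moving-average weight tracker (assumed interface)."""

    def __init__(self):
        self.decay = 0.0
        self.shadow = {}

    def register(self, name, value):
        # keep a detached copy of the parameter as the running average
        self.shadow[name] = value.clone()

    def grow(self, generator, decay):
        # after the networks grow, adopt the new decay and register any new parameters
        self.decay = decay
        for name, param in generator.named_parameters():
            if param.requires_grad and name not in self.shadow:
                self.shadow[name] = param.data.clone()

    def update(self, generator):
        # blend the current weights into the running average
        for name, param in generator.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (
                    self.decay * self.shadow[name] + (1.0 - self.decay) * param.data
                )

    def set_weights(self, target):
        # copy the averaged weights into the EMA copy of the generator
        for name, param in target.named_parameters():
            if name in self.shadow:
                param.data.copy_(self.shadow[name])
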
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """
    variable definition
    """
    real_domain_labels = 1
    fake_domain_labels = 0

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()

    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_1 = encoder(Variable(X_1))
            style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            _, __, class_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(style_1, class_1)
            reconstructed_X_2 = decoder(style_1, class_2)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_1))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2
            kl_divergence_error = kl_divergence_loss_1

            auto_encoder_optimizer.step()

            # B. run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                output_1 = discriminator(Variable(X_3), reconstructed_X_1_3)

                generator_error_1 = cross_entropy_loss(output_1, Variable(domain_labels))
                generator_error_1.backward(retain_graph=True)

                style_latent_space.normal_(0., 1.)
                reconstructed_X_latent_3 = decoder(Variable(style_latent_space), class_3)

                output_2 = discriminator(Variable(X_3), reconstructed_X_latent_3)

                generator_error_2 = cross_entropy_loss(output_2, Variable(domain_labels))
                generator_error_2.backward()

                generator_error = generator_error_1 + generator_error_2
                kl_divergence_error += kl_divergence_loss_1

                generator_optimizer.step()

            # C. run the discriminator
            for i in range(FLAGS.discriminator_times):

                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, _, __ = next(loader)
                image_batch_2, image_batch_3, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                real_output = discriminator(Variable(X_2), Variable(X_3))

                discriminator_real_error = cross_entropy_loss(real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                fake_output = discriminator(Variable(X_3), reconstructed_X_1_3)

                discriminator_fake_error = cross_entropy_loss(fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data == target_true_labels.long()
                                          ).sum().item() / (FLAGS.batch_size * 2)

                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                    discriminator_optimizer.step()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.item()))
                print('KL-Divergence loss: ' + str(kl_divergence_error.item()))

                print('')
                print('Generator loss: ' + str(generator_error.item()))
                print('Discriminator loss: ' + str(discriminator_error.item()))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                    epoch,
                    iteration,
                    reconstruction_error.item(),
                    kl_divergence_error.item(),
                    generator_error.item(),
                    discriminator_error.item(),
                    discriminator_accuracy
                ))

            # write to tensorboard
            step = epoch * (len(paired_mnist) // FLAGS.batch_size + 1) + iteration
            writer.add_scalar('Reconstruction loss', reconstruction_error.item(), step)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.item(), step)
            writer.add_scalar('Generator loss', generator_error.item(), step)
            writer.add_scalar('Discriminator loss', discriminator_error.item(), step)
            writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100, step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))
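
The procedure above calls a reparameterize helper that is defined elsewhere. It implements the standard VAE reparameterization trick; a sketch matching the training/mu/logvar call signature used above:

import torch

def reparameterize(training, mu, logvar):
    # standard VAE reparameterization trick: during training, sample
    # z = mu + sigma * eps with eps ~ N(0, I); at evaluation time, use the mean
    if training:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps
    return mu
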
Example #8
0
                                              num_workers=opt.n_workers)

    print(len(data_loader))

    G = Generator(opt)
    D = Discriminator(opt)

    G.apply(weight_init)
    D.apply(weight_init)

    print(G)
    print(D)

    if USE_CUDA:
        G = G.cuda()
        D = D.cuda()

    G_optim = torch.optim.Adam(G.parameters(),
                               lr=opt.lr,
                               betas=(opt.beta1, opt.beta2))
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=opt.lr,
                               betas=(opt.beta1, opt.beta2))

    GAN_loss = nn.BCELoss()
    L1_loss = nn.L1Loss()

    total_step = 0
    for epoch in range(opt.n_epoch):
        epoch += 1
        for i, (input, real) in enumerate(data_loader):
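            # (the original example is cut off here; the loop body below is a
            #  generic sketch of a pix2pix-style step under GAN_loss + L1_loss,
            #  with an assumed conditional discriminator D(input, image) and an
            #  assumed L1 weight of 100 -- it is not the original code)
            if USE_CUDA:
                input, real = input.cuda(), real.cuda()

            fake = G(input)

            # update D: push D(input, real) towards 1 and D(input, fake) towards 0
            D_optim.zero_grad()
            d_real = D(input, real)
            d_fake = D(input, fake.detach())
            d_loss = GAN_loss(d_real, torch.ones_like(d_real)) \
                     + GAN_loss(d_fake, torch.zeros_like(d_fake))
            d_loss.backward()
            D_optim.step()

            # update G: adversarial term plus weighted L1 reconstruction
            G_optim.zero_grad()
            d_fake = D(input, fake)
            g_loss = GAN_loss(d_fake, torch.ones_like(d_fake)) \
                     + 100 * L1_loss(fake, real)
            g_loss.backward()
            G_optim.step()

            total_step += 1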
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))
    """
    variable definition
    """
    real_domain_labels = 1
    fake_domain_labels = 0

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    generator_optimizer = optim.Adam(list(encoder.parameters()) +
                                     list(decoder.parameters()),
                                     lr=FLAGS.initial_learning_rate,
                                     betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_2 = decoder(nv_2, nc_1)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2

            if FLAGS.train_auto_encoder:
                auto_encoder_optimizer.step()

            # B. run the adversarial part of the architecture

            # B. a) run the discriminator
            for i in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, image_batch_2, labels_batch_1 = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                real_output = discriminator(Variable(X_1), Variable(X_2))

                discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                    real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                image_batch_3, _, labels_batch_3 = next(loader)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])

                fake_output = discriminator(Variable(X_1), reconstructed_X_3_1)

                discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                    fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(
                    FLAGS.batch_size), torch.zeros(FLAGS.batch_size)),
                                               dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat(
                    (real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(
                    discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data
                                          == target_true_labels.long()).sum(
                                          ).item() / (FLAGS.batch_size * 2)

                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy and FLAGS.train_discriminator:
                    discriminator_optimizer.step()

            # B. b) run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, labels_batch_1 = next(loader)
                image_batch_3, __, labels_batch_3 = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])

                output = discriminator(Variable(X_1), reconstructed_X_3_1)

                generator_error = FLAGS.gen_coef * cross_entropy_loss(
                    output, Variable(domain_labels))
                generator_error.backward()

                if FLAGS.train_generator:
                    generator_optimizer.step()

            # print progress every 10 iterations
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.item()))
                print('Generator loss: ' + str(generator_error.item()))

                print('')
                print('Discriminator loss: ' + str(discriminator_error.item()))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                    epoch, iteration,
                    reconstruction_error.item(),
                    generator_error.item(),
                    discriminator_error.item(),
                    discriminator_accuracy))

            # write to tensorboard
            step = epoch * (len(paired_mnist) // FLAGS.batch_size + 1) + iteration
            writer.add_scalar('Reconstruction loss', reconstruction_error.item(), step)
            writer.add_scalar('Generator loss', generator_error.item(), step)
            writer.add_scalar('Discriminator loss', discriminator_error.item(), step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(),
                       os.path.join('checkpoints', FLAGS.discriminator_save))
            """
            save reconstructed images and style swapped image generations to check progress
            """
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))
            nv_3, nc_3 = encoder(Variable(X_3))

            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_3_2 = decoder(nv_3, nc_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            image_batch = np.concatenate(
                (image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(
                reconstructed_X_1.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_x = np.concatenate(
                (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x,
                        name=str(epoch) + '_target',
                        save=True)

            # save cross reconstructed batch
            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            style_batch = np.concatenate(
                (style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            reconstructed_style = np.transpose(
                reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_style = np.concatenate(
                (reconstructed_style, reconstructed_style,
                 reconstructed_style),
                axis=3)
            imshow_grid(reconstructed_style,
                        name=str(epoch) + '_style_target',
                        save=True)
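
imshow_grid is an external utility used above to dump image grids. A minimal version, assuming it tiles a batch of (N, H, W, C) arrays into a grid and saves the figure under reconstructed_images/ when save=True, might be:

import os

import matplotlib.pyplot as plt
import numpy as np

def imshow_grid(images, shape=(4, 4), name='grid', save=False):
    # sketch of an assumed helper: tile the first shape[0] * shape[1] images
    # of a batch into a grid, then show it or save it to disk
    fig, axes = plt.subplots(shape[0], shape[1], figsize=(8, 8))
    for ax, img in zip(axes.flat, images[:shape[0] * shape[1]]):
        ax.imshow(np.squeeze(img))
        ax.axis('off')
    if save:
        fig.savefig(os.path.join('reconstructed_images', name + '.png'))
        plt.close(fig)
    else:
        plt.show()
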
Example #10
0
    dataset = datasets.MNIST('.', transform=transform, download=True)
    dataloader = data.DataLoader(dataset, batch_size=4)

    # model
    g = Generator()
    d = Discriminator()

    # losses
    gan_loss = GANLoss()

    # use CUDA if available
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        g = g.cuda()
        d = d.cuda()

    # optimizer
    optim_G = optim.Adam(g.parameters())
    optim_D = optim.Adam(d.parameters())

    # train
    for epoch in range(num_epoch):
        total_batch = len(dataloader)

        for idx, (image, label) in enumerate(dataloader):
            d.train()
            g.train()

            # generate fake images
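            # (the original example is truncated after this comment; the rest of
            #  the loop body below is a generic DCGAN-style sketch, assuming
            #  GANLoss takes (prediction, target) like nn.BCELoss and that the
            #  Generator consumes a latent vector of an assumed size 100)
            batch = image.size(0)
            z = torch.randn(batch, 100)
            if is_cuda:
                image, z = image.cuda(), z.cuda()
            fake = g(z)

            # update D: real batch towards 1, fake batch towards 0
            optim_D.zero_grad()
            out_real = d(image)
            out_fake = d(fake.detach())
            loss_d = gan_loss(out_real, torch.ones_like(out_real)) \
                     + gan_loss(out_fake, torch.zeros_like(out_fake))
            loss_d.backward()
            optim_D.step()

            # update G: make D predict 1 on the fakes
            optim_G.zero_grad()
            out_fake = d(fake)
            loss_g = gan_loss(out_fake, torch.ones_like(out_fake))
            loss_g.backward()
            optim_G.step()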
Example #11
0
print(y_fixed)
x_fixed = sample_batched['seg_mask'][0:num_test_img]
x_fixed = torch.squeeze(x_fixed)
x_fixed = torch.sum(x_fixed, dim=1)
x_fixed = x_fixed.view(x_fixed.size(0), 1, x_fixed.size(1), x_fixed.size(2))
x_fixed = Variable(x_fixed.cuda())


z_fixed = torch.randn((num_test_img, noise_size))
z_fixed = Variable(z_fixed.cuda())

#--------------------------Define Networks------------------------------------
G_local = Generator_Baseline_2(z_dim=noise_size, label_channel=len(category_names),num_res_blocks=num_res_blocks)
D_local = Discriminator(channels=1+len(category_names), input_size=img_size)
G_local.cuda()
D_local.cuda()

# Load parameters from pre-trained model
if load_params:
    G_local.load_state_dict(torch.load('result_mask/models_local_'+str(log_numb)+'/coco_model_G_glob_epoch_'+str(epoch_bias)+'.pth'))
    D_local.load_state_dict(torch.load('result_mask/models_local_'+str(log_numb)+'/coco_model_D_glob_epoch_'+str(epoch_bias)+'.pth'))
    print('Parameters are loaded from logFile: models_local_' +str(log_numb) +' ---- Epoch: '+str(epoch_bias))


# GAN loss: least-squares (MSE) when use_LSGAN_loss is set, otherwise binary cross-entropy
if use_LSGAN_loss:
    BCE_loss = nn.MSELoss()
else:
    BCE_loss = nn.BCELoss()
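
# The snippet ends here. Whichever criterion was selected above is applied the
# same way against real/fake target tensors; a minimal usage sketch with assumed
# names (not from the original script):
#
#     y_real = torch.ones_like(d_real_out)
#     y_fake = torch.zeros_like(d_fake_out)
#     d_loss = BCE_loss(d_real_out, y_real) + BCE_loss(d_fake_out, y_fake)
#
# With nn.MSELoss these targets yield the least-squares GAN objective; with
# nn.BCELoss they yield the standard discriminator cross-entropy loss.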