Example #1
0
    def __init__(self, args) -> None:
        """Set up the Pix2Pix trainer from parsed CLI arguments.

        Builds the generator/discriminator pair on CUDA, their Adam
        optimizers, the GAN (BCE-with-logits) and L1 criteria, and the
        output directories plus a file logger under ``args.path``.
        """
        # Training hyper-parameters pulled off the args namespace.
        self.lr = args.learning_rate
        self.LAMBDA = args.LAMBDA  # weight of the L1 term in the generator loss
        self.save = args.save
        self.batch_size = args.batch_size
        self.path = args.path
        self.n_epochs = args.epoch_num
        self.eval_interval = 10  # log every 10 batches; eval cadence derives from this
        # Per-batch loss histories, kept for the whole run.
        self.G_image_loss = []
        self.G_GAN_loss = []
        self.G_total_loss = []
        self.D_loss = []
        # NOTE(review): hard-coded "cuda" — this trainer assumes a GPU is present.
        self.netG = Generator().to("cuda")
        self.netD = Discriminator().to("cuda")
        self.optimizerG = flow.optim.Adam(self.netG.parameters(),
                                          lr=self.lr,
                                          betas=(0.5, 0.999))
        self.optimizerD = flow.optim.Adam(self.netD.parameters(),
                                          lr=self.lr,
                                          betas=(0.5, 0.999))
        self.criterionGAN = flow.nn.BCEWithLogitsLoss()
        self.criterionL1 = flow.nn.L1Loss()

        # Output locations for model checkpoints and sample images.
        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.test_images_path = os.path.join(self.path, "test_images")

        mkdirs(self.checkpoint_path, self.test_images_path)
        self.logger = init_logger(os.path.join(self.path, "log.txt"))
Example #2
0
    def __init__(self, config):
        """Build the NormalUnet model: a U-Net generator, a patch
        discriminator, and the dataset-specific data loader.
        """
        super(NormalUnet, self).__init__(config)

        self.G = Unet(conv_dim=config.g_conv_dim,
                      n_layers=config.g_layers,
                      max_dim=config.max_conv_dim,
                      im_channels=config.img_channels,
                      skip_connections=config.skip_connections,
                      vgg_like=config.vgg_like)
        # NOTE(review): im_channels is hard-coded to 3 here while the
        # generator uses config.img_channels — confirm this asymmetry is intended.
        self.D = Discriminator(image_size=config.image_size,
                               im_channels=3,
                               attr_dim=1,
                               conv_dim=config.d_conv_dim,
                               n_layers=config.d_layers,
                               max_dim=config.max_conv_dim,
                               fc_dim=config.d_fc_dim)

        print(self.G)
        if self.config.use_image_disc:
            print(self.D)

        # Resolve the loader factory by dataset name: dataset='foo' looks up
        # a module-level function named `foo_loader` via globals().
        self.data_loader = globals()['{}_loader'.format(self.config.dataset)](
            self.config.data_root, self.config.mode, self.config.attrs,
            self.config.crop_size, self.config.image_size,
            self.config.batch_size, self.config.data_augmentation)

        self.logger.info("NormalUnet ready")
Example #3
0
    def __init__(self, alpha, beta, lambda1, lambda2, n_pixels,
                 learning_rate_decay, learning_rate_interval, g_hidden,
                 d_hidden, max_buffer_size):
        """Build the paired generators/discriminators and their optimizers.

        Args:
            alpha: Adam learning rate.
            beta: Adam beta1 (momentum) coefficient.
            lambda1, lambda2: loss-weighting coefficients (stored for later use).
            n_pixels: flattened image width; generator input and output size.
            learning_rate_decay: LR decay factor (stored for later use).
            learning_rate_interval: iterations between LR decays (stored).
            g_hidden: generator hidden-layer width.
            d_hidden: discriminator hidden-layer width.
            max_buffer_size: capacity of the per-domain history buffers.
        """
        self.max_buffer_size = max_buffer_size

        # Generators map images to images: same flattened width in and out.
        g_input_size = n_pixels
        g_hidden_size = g_hidden
        g_output_size = n_pixels

        # Discriminators consume a generated image and emit a single score.
        d_input_size = g_output_size
        d_hidden_size = d_hidden
        d_output_size = 1

        self.g = GenerativeNetwork(g_hidden_size,
                                   g_output_size,
                                   n_input=g_input_size)
        self.f = GenerativeNetwork(g_hidden_size,
                                   g_output_size,
                                   n_input=g_input_size)
        self.x = Discriminator(d_hidden_size,
                               d_output_size,
                               n_input=d_input_size)
        self.y = Discriminator(d_hidden_size,
                               d_output_size,
                               n_input=d_input_size)

        # One Adam optimizer per network, sharing the same hyper-parameters.
        self.opt_g = optimizers.Adam(alpha=alpha, beta1=beta)
        self.opt_f = optimizers.Adam(alpha=alpha, beta1=beta)
        self.opt_x = optimizers.Adam(alpha=alpha, beta1=beta)
        self.opt_y = optimizers.Adam(alpha=alpha, beta1=beta)

        self.opt_g.setup(self.g)
        self.opt_f.setup(self.f)
        self.opt_x.setup(self.x)
        self.opt_y.setup(self.y)

        # NOTE(review): use_cleargrads() is deprecated in modern Chainer
        # (clearing grads is the default behavior); kept for compatibility
        # with the Chainer version this code targets — confirm before removing.
        self.opt_g.use_cleargrads()
        self.opt_f.use_cleargrads()
        self.opt_x.use_cleargrads()
        self.opt_y.use_cleargrads()

        self.n_pixels = n_pixels
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.learning_rate_decay = learning_rate_decay
        self.learning_rate_interval = learning_rate_interval

        # Pre-allocated float32 history buffers, one per image domain.
        self.buffer = {}
        self.buffer['x'] = np.zeros(
            (self.max_buffer_size, self.n_pixels)).astype('float32')
        self.buffer['y'] = np.zeros(
            (self.max_buffer_size, self.n_pixels)).astype('float32')
Example #4
0
 def load_weight(self, pathlist: dict):
     """Restore generator/discriminator ensembles from checkpoint paths.

     ``pathlist`` maps ``'net_G'`` and ``'net_D'`` to iterables of
     state-dict file paths; one freshly constructed network is loaded
     per path and collected into ``self.net_Gs`` / ``self.net_Ds``.
     """
     self.net_Gs = []
     self.net_Ds = []
     # Identical restore recipe for both network families.
     for factory, key, bucket in ((Generator, 'net_G', self.net_Gs),
                                  (Discriminator, 'net_D', self.net_Ds)):
         for ckpt in pathlist[key]:
             net = factory(self.opt).to(self.device)
             net.load_state_dict(torch.load(ckpt, map_location=self.device))
             bucket.append(net)
Example #5
0
    def __init__(self, opt):
        """Build the ALI model: variational encoder/decoder, a joint
        discriminator, their Adam optimizers, and the input buffer.
        """
        super(ALI, self).__init__(opt)

        # define input tensors
        self.gpu_ids = opt.gpu_ids
        self.batch_size = opt.batch_size

        self.encoder = VariationalEncoder(gpu_ids=self.gpu_ids,
                                          k=self.opt.z_dimension)
        self.decoder = VariationalDecoder(gpu_ids=self.gpu_ids,
                                          k=self.opt.z_dimension)
        if self.gpu_ids:
            self.encoder.cuda(device=opt.gpu_ids[0])
            self.decoder.cuda(device=opt.gpu_ids[0])
        # NOTE(review): betas=(0.5, 1e-3) sets beta2=1e-3, not the usual
        # 0.999 — presumably following the ALI paper's settings; confirm.
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(),
                                                  lr=self.opt.lr,
                                                  betas=(0.5, 1e-3))
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(),
                                                  lr=self.opt.lr,
                                                  betas=(0.5, 1e-3))
        self.discriminator = Discriminator(gpu_ids=self.gpu_ids)
        if self.gpu_ids:
            self.discriminator.cuda(device=opt.gpu_ids[0])
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.opt.lr, betas=(0.5, 1e-3))

        # normal initialization.
        self.encoder.apply(normal_weight_init)
        self.decoder.apply(normal_weight_init)
        self.discriminator.apply(normal_weight_init)

        # Encoder output and decoder input must share the latent size.
        assert self.decoder.k == self.encoder.k

        # input
        self.input = self.Tensor(opt.batch_size, opt.input_channel, opt.height,
                                 opt.width)
        self.x = None

        # Populated lazily by set_z()/forward().
        self.normal_z = None
        self.sampled_x = None
        self.sampled_z = None
        self.d_sampled_x = None
        self.d_sampled_z = None

        # losses
        self.loss_function = GANLoss(len(self.gpu_ids) > 0)
        self.D_loss = None
        self.G_loss = None
Example #6
0
  def __init__(self, opt):
    """Build the Pose_GAN module: pose-conditioned generator (stacked or
    baseline), a discriminator that also sees the output image, Adam
    optimizers, and Xavier weight initialization on CUDA.
    """
    super(Pose_GAN, self).__init__()

    # load generator and discriminator models
    # adding extra layers for larger image size
    # checkMode != 0 uses a drastically reduced filter stack (smoke-test mode).
    if(opt.checkMode == 0):
      nfilters_decoder = (512, 512, 512, 256, 128, 3) if max(opt.image_size) < 256 else (512, 512, 512, 512, 256, 128, 3)
      nfilters_encoder = (64, 128, 256, 512, 512, 512) if max(opt.image_size) < 256 else (64, 128, 256, 512, 512, 512, 512)
    else:
      nfilters_decoder = (128, 3) if max(opt.image_size) < 256 else (256, 128, 3)
      nfilters_encoder = (64, 128) if max(opt.image_size) < 256 else (64, 128, 256)

    # 3 RGB channels plus one (or two) pose-map stacks on the input.
    if (opt.use_input_pose):
      input_nc = 3 + 2*opt.pose_dim
    else:
      input_nc = 3 + opt.pose_dim

    self.num_stacks = opt.num_stacks
    self.batch_size = opt.batch_size
    self.pose_dim = opt.pose_dim
    if(opt.gen_type=='stacked'):
      self.gen = Stacked_Generator(input_nc, opt.num_stacks, opt.pose_dim, nfilters_encoder, nfilters_decoder, use_input_pose=opt.use_input_pose)
    elif(opt.gen_type=='baseline'):
      self.gen = Generator(input_nc, nfilters_encoder, nfilters_decoder, use_input_pose=opt.use_input_pose)
    else:
      raise Exception('Invalid gen_type')
    # discriminator also sees the output image for the target pose
    self.disc = Discriminator(input_nc + 3, use_input_pose=opt.use_input_pose, checkMode=opt.checkMode)
    print('---------- Networks initialized -------------')
    print_network(self.gen)
    print_network(self.disc)
    print('-----------------------------------------------')
    # Setup the optimizers
    lr = opt.learning_rate
    self.disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
    self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))

    # Network weight initialization
    self.gen.cuda()
    self.disc.cuda()
    self.disc.apply(xavier_weights_init)
    self.gen.apply(xavier_weights_init)

    # Setup the loss function for training
    self.ll_loss_criterion = torch.nn.L1Loss()
Example #7
0
    def __init__(self, opt):
        """Build the ALI model: variational encoder/decoder, a joint
        discriminator, their Adam optimizers, and the input buffer.
        """
        super(ALI, self).__init__(opt)

        # define input tensors
        self.gpu_ids = opt.gpu_ids
        self.batch_size = opt.batch_size

        self.encoder = VariationalEncoder(gpu_ids=self.gpu_ids, k=self.opt.z_dimension)
        self.decoder = VariationalDecoder(gpu_ids=self.gpu_ids, k=self.opt.z_dimension)
        if self.gpu_ids:
            self.encoder.cuda(device=opt.gpu_ids[0])
            self.decoder.cuda(device=opt.gpu_ids[0])
        # NOTE(review): betas=(0.5, 1e-3) sets beta2=1e-3, not the usual
        # 0.999 — presumably following the ALI paper's settings; confirm.
        self.encoder_optimizer = torch.optim.Adam(
            self.encoder.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )
        self.decoder_optimizer = torch.optim.Adam(
            self.decoder.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )
        self.discriminator = Discriminator(gpu_ids=self.gpu_ids)
        if self.gpu_ids:
            self.discriminator.cuda(device=opt.gpu_ids[0])
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )

        # normal initialization.
        self.encoder.apply(normal_weight_init)
        self.decoder.apply(normal_weight_init)
        self.discriminator.apply(normal_weight_init)

        # Encoder output and decoder input must share the latent size.
        assert self.decoder.k == self.encoder.k

        # input
        self.input = self.Tensor(
            opt.batch_size,
            opt.input_channel,
            opt.height,
            opt.width
        )
        self.x = None

        # Populated lazily by set_z()/forward().
        self.normal_z = None
        self.sampled_x = None
        self.sampled_z = None
        self.d_sampled_x = None
        self.d_sampled_z = None

        # losses
        self.loss_function = GANLoss(len(self.gpu_ids) > 0)
        self.D_loss = None
        self.G_loss = None
Example #8
0
    def __init__(self, opt):
        """CycleGAN-style model: two generators (A→B and B→A); when
        training, also two discriminators, Adam optimizers, a global step,
        and image-history buffers. Always attempts a checkpoint restore.
        """
        self.opt = opt

        self.genA2B = Generator(opt)
        self.genB2A = Generator(opt)

        if opt.training:
            self.discA = Discriminator(opt)
            self.discB = Discriminator(opt)
            # Eager variable so the LR can be decayed in place during training.
            self.learning_rate = tf.contrib.eager.Variable(
                opt.lr, dtype=tf.float32, name='learning_rate')
            self.disc_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                     beta1=opt.beta1)
            self.gen_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                    beta1=opt.beta1)
            self.global_step = tf.train.get_or_create_global_step()
            # Initialize history buffers:
            self.discA_buffer = ImageHistoryBuffer(opt)
            self.discB_buffer = ImageHistoryBuffer(opt)
        # Restore latest checkpoint:
        self.initialize_checkpoint()
        if not opt.training or opt.load_checkpoint:
            self.restore_checkpoint()
Example #9
0
class Pix2Pix:
    """Pix2Pix image-to-image translation trainer (facades dataset).

    Trains a Generator/Discriminator pair with a PatchGAN adversarial
    loss plus an L1 reconstruction loss weighted by ``LAMBDA``, logging
    per-batch losses and periodically saving sample images and
    checkpoints under ``args.path``.
    """

    def __init__(self, args) -> None:
        """Set up networks, optimizers, criteria, output dirs and logger."""
        self.lr = args.learning_rate
        self.LAMBDA = args.LAMBDA  # weight of the L1 term in the generator loss
        self.save = args.save
        self.batch_size = args.batch_size
        self.path = args.path
        self.n_epochs = args.epoch_num
        self.eval_interval = 10  # batches between log lines; eval cadence derives from it
        # Per-batch loss histories, saved as .npy after training.
        self.G_image_loss = []
        self.G_GAN_loss = []
        self.G_total_loss = []
        self.D_loss = []
        self.netG = Generator().to("cuda")
        self.netD = Discriminator().to("cuda")
        self.optimizerG = flow.optim.Adam(self.netG.parameters(),
                                          lr=self.lr,
                                          betas=(0.5, 0.999))
        self.optimizerD = flow.optim.Adam(self.netD.parameters(),
                                          lr=self.lr,
                                          betas=(0.5, 0.999))
        self.criterionGAN = flow.nn.BCEWithLogitsLoss()
        self.criterionL1 = flow.nn.L1Loss()

        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.test_images_path = os.path.join(self.path, "test_images")

        mkdirs(self.checkpoint_path, self.test_images_path)
        self.logger = init_logger(os.path.join(self.path, "log.txt"))

    def train(self):
        """Run the full training loop; assumes ``n_epochs >= 1``."""
        # init dataset
        x, y = load_facades()
        # flow.Tensor() bug in here
        x, y = np.ascontiguousarray(x), np.ascontiguousarray(y)
        # Fixed batch reused for qualitative eval images across epochs.
        self.fixed_inp = to_tensor(x[:self.batch_size].astype(np.float32))
        self.fixed_target = to_tensor(y[:self.batch_size].astype(np.float32))

        batch_num = len(x) // self.batch_size
        # PatchGAN targets: the discriminator emits a 1x30x30 patch map.
        label1 = to_tensor(np.ones((self.batch_size, 1, 30, 30)),
                           dtype=flow.float32)
        label0 = to_tensor(np.zeros((self.batch_size, 1, 30, 30)),
                           dtype=flow.float32)

        for epoch_idx in range(self.n_epochs):
            self.netG.train()
            self.netD.train()
            start = time.time()

            # run every epoch to shuffle
            for batch_idx in range(batch_num):
                inp = to_tensor(x[batch_idx * self.batch_size:(batch_idx + 1) *
                                  self.batch_size].astype(np.float32))
                target = to_tensor(
                    y[batch_idx * self.batch_size:(batch_idx + 1) *
                      self.batch_size].astype(np.float32))

                # update D
                d_fake_loss, d_real_loss, d_loss = self.train_discriminator(
                    inp, target, label0, label1)

                # update G
                g_gan_loss, g_image_loss, g_total_loss, g_out = self.train_generator(
                    inp, target, label1)

                self.G_GAN_loss.append(g_gan_loss)
                self.G_image_loss.append(g_image_loss)
                self.G_total_loss.append(g_total_loss)
                self.D_loss.append(d_loss)
                if (batch_idx + 1) % self.eval_interval == 0:
                    self.logger.info(
                        "{}th epoch, {}th batch, d_fakeloss:{:>8.4f}, d_realloss:{:>8.4f},  ggan_loss:{:>8.4f}, gl1_loss:{:>8.4f}"
                        .format(
                            epoch_idx + 1,
                            batch_idx + 1,
                            d_fake_loss,
                            d_real_loss,
                            g_gan_loss,
                            g_image_loss,
                        ))

            self.logger.info("Time for epoch {} is {} sec.".format(
                epoch_idx + 1,
                time.time() - start))

            # BUG FIX: original read `% 2 * self.eval_interval`, which Python
            # parses as `((epoch_idx + 1) % 2) * self.eval_interval` and thus
            # fired on every even-numbered epoch.  The intent is "every
            # 2 * eval_interval epochs".
            if (epoch_idx + 1) % (2 * self.eval_interval) == 0:
                # save .train() images
                # save .eval() images
                self._eval_generator_and_save_images(epoch_idx)

        if self.save:
            # `epoch_idx` here is the last epoch of the loop above.
            flow.save(
                self.netG.state_dict(),
                os.path.join(self.checkpoint_path,
                             "pix2pix_g_{}".format(epoch_idx + 1)),
            )

            flow.save(
                self.netD.state_dict(),
                os.path.join(self.checkpoint_path,
                             "pix2pix_d_{}".format(epoch_idx + 1)),
            )

            # save train loss and val error to plot
            np.save(
                os.path.join(self.path,
                             "G_image_loss_{}.npy".format(self.n_epochs)),
                self.G_image_loss,
            )
            np.save(
                os.path.join(self.path,
                             "G_GAN_loss_{}.npy".format(self.n_epochs)),
                self.G_GAN_loss,
            )
            np.save(
                os.path.join(self.path,
                             "G_total_loss_{}.npy".format(self.n_epochs)),
                self.G_total_loss,
            )
            np.save(
                os.path.join(self.path, "D_loss_{}.npy".format(self.n_epochs)),
                self.D_loss,
            )
            self.logger.info("*************** Train done ***************** ")

    def train_generator(self, input, target, label1):
        """One generator step.

        Combines the adversarial loss (discriminator fooled into `label1`)
        with the LAMBDA-weighted L1 loss to `target`.  Returns numpy
        (gan_loss, weighted_l1_loss, total_loss, generated_batch).
        """
        g_out = self.netG(input)
        # First, G(A) should fake the discriminator
        fake_AB = flow.cat([input, g_out], 1)
        pred_fake = self.netD(fake_AB)
        gan_loss = self.criterionGAN(pred_fake, label1)
        # Second, G(A) = B
        l1_loss = self.criterionL1(g_out, target)
        # combine loss and calculate gradients
        g_loss = gan_loss + self.LAMBDA * l1_loss
        g_loss.backward()

        self.optimizerG.step()
        self.optimizerG.zero_grad()
        return (
            to_numpy(gan_loss),
            to_numpy(self.LAMBDA * l1_loss),
            to_numpy(g_loss),
            to_numpy(g_out, False),
        )

    def train_discriminator(self, input, target, label0, label1):
        """One discriminator step on a fake (detached G output) and a real
        pair; the two BCE losses are averaged.  Returns numpy losses."""
        g_out = self.netG(input)
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = flow.cat([input, g_out.detach()], 1)
        pred_fake = self.netD(fake_AB)
        d_fake_loss = self.criterionGAN(pred_fake, label0)

        # Real
        real_AB = flow.cat([input, target], 1)
        pred_real = self.netD(real_AB)
        d_real_loss = self.criterionGAN(pred_real, label1)

        # combine loss and calculate gradients
        d_loss = (d_fake_loss + d_real_loss) * 0.5
        d_loss.backward()
        self.optimizerD.step()
        self.optimizerD.zero_grad()
        return to_numpy(d_fake_loss), to_numpy(d_real_loss), to_numpy(d_loss)

    def _eval_generator_and_save_images(self, epoch_idx):
        """Generate images for the fixed batch and write them to disk."""
        results = self._eval_generator()
        save_images(
            results,
            to_numpy(self.fixed_inp, False),
            to_numpy(self.fixed_target, False),
            path=os.path.join(self.test_images_path,
                              "testimage_{:02d}.png".format(epoch_idx + 1)),
        )

    def _eval_generator(self):
        """Run the generator in eval mode (no grad) on the fixed batch."""
        self.netG.eval()
        with flow.no_grad():
            g_out = self.netG(self.fixed_inp)
        return to_numpy(g_out, False)
Example #10
0
class ALI(BaseModel):
    """Adversarially Learned Inference (ALI) model.

    Pairs a variational encoder/decoder (the generator side) with a joint
    discriminator over (x, z) pairs.  The discriminator is trained to tell
    (real x, inferred z) from (generated x, prior z); the encoder/decoder
    are trained with the opposite labels.

    NOTE(review): this code uses the long-deprecated
    ``Variable(..., volatile=...)`` API and so targets PyTorch < 0.4;
    behavior is preserved as-is.
    """

    def __init__(self, opt):
        """Build networks, optimizers, the input buffer and the GAN loss."""
        super(ALI, self).__init__(opt)

        # define input tensors
        self.gpu_ids = opt.gpu_ids
        self.batch_size = opt.batch_size

        # next lines added by lzh
        self.z = None
        self.infer_z = None
        self.infer_x = None
        self.sampling_count = opt.sampling_count

        self.encoder = VariationalEncoder(gpu_ids=self.gpu_ids, k=self.opt.z_dimension)
        self.decoder = VariationalDecoder(gpu_ids=self.gpu_ids, k=self.opt.z_dimension)

        if self.gpu_ids:
            self.encoder.cuda(device=opt.gpu_ids[0])
            self.decoder.cuda(device=opt.gpu_ids[0])
        # NOTE(review): betas=(0.5, 1e-3) sets beta2=1e-3, not the usual
        # 0.999 — presumably the ALI paper's settings; confirm before changing.
        self.encoder_optimizer = torch.optim.Adam(
            self.encoder.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )
        self.decoder_optimizer = torch.optim.Adam(
            self.decoder.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )
        self.discriminator = Discriminator(gpu_ids=self.gpu_ids)
        if self.gpu_ids:
            self.discriminator.cuda(device=opt.gpu_ids[0])
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )

        # normal initialization.
        self.encoder.apply(normal_weight_init)
        self.decoder.apply(normal_weight_init)
        self.discriminator.apply(normal_weight_init)

        # Encoder output and decoder input must share the latent size.
        assert self.decoder.k == self.encoder.k

        # input
        self.input = self.Tensor(
            opt.batch_size,
            opt.input_channel,
            opt.height,
            opt.width
        )
        self.x = None

        self.normal_z = None
        self.sampled_x = None
        self.sampled_z = None
        self.d_sampled_x = None
        self.d_sampled_z = None

        # losses
        self.loss_function = GANLoss(len(self.gpu_ids) > 0)
        self.D_loss = None
        self.G_loss = None

    def set_input(self, data, is_z_given=False):
        """Copy a data batch into the input buffer; optionally resample z."""
        temp = self.input.clone()
        temp.resize_(self.input.size())
        temp.copy_(self.input)
        self.input = temp
        self.input.resize_(data.size()).copy_(data)
        if not is_z_given:
            self.set_z()

    def set_z(self, var=None, volatile=False):
        """Set ``self.normal_z``: use ``var`` when given, else sample N(0, I).

        BUG FIX: the original branches were inverted — ``set_z()`` (the call
        made by ``set_input``) assigned ``None`` to ``normal_z``, which would
        crash the decoder in ``forward``, while an explicitly supplied ``var``
        was discarded and replaced with random noise.
        """
        if var is not None:
            self.normal_z = var
        else:
            if self.gpu_ids:
                self.normal_z = Variable(torch.randn((self.opt.batch_size, self.encoder.k)).cuda(), volatile=volatile)
            else:
                self.normal_z = Variable(torch.randn((self.opt.batch_size, self.encoder.k)), volatile=volatile)

    def forward(self, ic=0, volatile=False):
        """Generate (sampled_x, sampled_z) and, when training, score both
        joint pairs with the discriminator.

        `ic` extra encode/decode round-trips refine the inferred x before
        the final encoding.
        """
        # volatile : no back gradient.
        self.x = Variable(self.input, volatile=volatile)
        # Before call self.decoder, normal_z must be set.
        self.sampled_x = self.decoder(self.normal_z)
        self.infer_x = Variable(self.input, volatile=volatile)
        for i in range(ic):
            self.infer_z = self.encoder(self.infer_x)
            self.infer_x = self.decoder(self.infer_z)
        self.sampled_z = self.encoder(self.infer_x)

        if not volatile:
            self.d_sampled_x = self.discriminator(self.x, self.sampled_z)
            self.d_sampled_z = self.discriminator(self.sampled_x, self.normal_z)

    def test(self):
        """Inference-only forward pass (no gradients)."""
        self.forward(volatile=True)

    def mlpforward(self, iternum, volatile=False):
        """Like ``forward`` but with an explicit number of encode/decode
        refinement iterations (``iternum``)."""
        # volatile : no back gradient.
        self.x = Variable(self.input, volatile=volatile)
        # Before call self.decoder, normal_z must be set.
        self.sampled_x = self.decoder(self.normal_z)
        self.infer_x = Variable(self.input, volatile=volatile)
        for i in range(iternum):
            self.infer_z = self.encoder(self.infer_x)
            self.infer_x = self.decoder(self.infer_z)
        self.sampled_z = self.encoder(self.infer_x)
        if not volatile:
            self.d_sampled_x = self.discriminator(self.x, self.sampled_z)
            self.d_sampled_z = self.discriminator(self.sampled_x, self.normal_z)

    def forward_encoder(self, var):
        """Encode ``var`` to latent space."""
        return self.encoder(var)

    def forward_decoder(self, var):
        """Decode latent ``var`` to image space."""
        return self.decoder(var)

    def _gan_step(self, ic=0):
        """One discriminator update followed by one encoder/decoder update,
        after a forward pass with ``ic`` refinement iterations."""
        self.forward(ic)
        # update discriminator
        self.discriminator_optimizer.zero_grad()
        self.backward_D()
        self.discriminator_optimizer.step()
        # update generator
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.backward_G()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

    def optimize_parameters(self, inferring_count=0):
        """Run one base D/G update, then one extra update per refinement
        level up to ``inferring_count`` (same schedule as the original
        duplicated if/else code, factored through ``_gan_step``)."""
        self._gan_step()
        for ic in range(inferring_count):
            self._gan_step(ic + 1)

    def backward_D(self):
        """Discriminator loss: real pair labeled 1, generated pair labeled 0."""
        self.D_loss = self.loss_function(
            self.d_sampled_x, 1.
        ) + self.loss_function(
            self.d_sampled_z, 0.
        )
        self.D_loss.backward(retain_graph=True)

    def backward_G(self):
        """Generator loss: labels flipped relative to ``backward_D``."""
        self.G_loss = self.loss_function(
            self.d_sampled_x, 0.
        ) + self.loss_function(
            self.d_sampled_z, 1.
        )
        self.G_loss.backward(retain_graph=True)

    def get_losses(self):
        """Return the latest scalar D/G losses as an ordered dict."""
        return OrderedDict([
            ('D_loss', self.D_loss.cpu().item()),
            ('G_loss', self.G_loss.cpu().item()),
        ])

    def get_visuals(self, sample_single_image=True):
        """Return image arrays for the real batch and the decoder samples."""
        # Both methods works.
        fake_x = tensor2im(self.sampled_x.data, sample_single_image=sample_single_image)
        real_x = tensor2im(self.x.data, sample_single_image=sample_single_image)
        return OrderedDict([('real_x', real_x), ('fake_x', fake_x)])

    # get images
    def get_infervisuals(self, infernum, sample_single_image=True):
        """Run ``infernum`` encode/decode round-trips on the input and
        return (real, inferred) image arrays."""
        # volatile : no back gradient.
        self.x = Variable(self.input, volatile=True)

        self.infer_x = Variable(self.input, volatile=True)
        for i in range(infernum):
            print(i)
            self.infer_z = self.encoder(self.infer_x)
            self.infer_x = self.decoder(self.infer_z)
        self.sampled_z = self.encoder(self.infer_x)

        # ------------------------------
        infer_x = tensor2im(self.infer_x.data, sample_single_image=sample_single_image)
        real_x = tensor2im(self.x.data, sample_single_image=sample_single_image)
        return OrderedDict([('real_x', real_x), ('infer_x', infer_x)])

    # get the latent variable of the input image
    def get_lv(self):
        """Return the most recently inferred latent batch."""
        return self.sampled_z

    def save(self, epoch):
        """Persist encoder/decoder/discriminator checkpoints for ``epoch``."""
        self.save_network(self.encoder, 'encoder', epoch, self.gpu_ids)
        self.save_network(self.decoder, 'decoder', epoch, self.gpu_ids)
        self.save_network(self.discriminator, 'discriminator', epoch, self.gpu_ids)

    def load(self, epoch):
        """Restore all three networks from ``epoch``'s checkpoints."""
        self.load_network(self.encoder, 'encoder', epoch)
        self.load_network(self.decoder, 'decoder', epoch)
        self.load_network(self.discriminator, 'discriminator', epoch)

    def remove(self, epoch):
        """Delete ``epoch``'s checkpoints; epoch 0 is never removed."""
        if epoch == 0:
            return
        self.remove_checkpoint('encoder', epoch)
        self.remove_checkpoint('decoder', epoch)
        self.remove_checkpoint('discriminator', epoch)

    # input x~p(x), get z and output the generate x'=G(z)
    def reconstruction(self, volatile=True):
        """Iteratively reconstruct the input through encoder/decoder.

        After each decode, pixels in the fixed region [0:3, 0:32, 0:16]
        are clamped back to the original input — presumably conditioning
        the reconstruction on the left half of a 32x32 image (TODO confirm
        the intended region against the data dimensions).
        """
        # volatile : no back gradient.
        self.x = Variable(self.input, volatile=volatile)
        self.reconstruct_x = Variable(self.input, volatile=volatile)
        for i in range(self.sampling_count):  # sampling_count=20
            self.z = self.encoder(self.reconstruct_x)
            self.reconstruct_x = self.decoder(self.z)
            for xx in range(self.batch_size):
                for ii in range(3):
                    for jj in range(32):
                        for kk in range(16):
                            self.reconstruct_x[xx, ii, jj, kk] = self.x[xx, ii, jj, kk]
        reconstruct_x = tensor2im(self.reconstruct_x.data, sample_single_image=False)
        real_x = tensor2im(self.input.data, sample_single_image=False)
        return OrderedDict([('real_x', real_x), ('reconstruct_x', reconstruct_x)])
Example #11
0
    # Pick the dataset root by which server this job is running on.
    elif args.server_port == 12315:
        data_root = '/hhd/chendaiyuan/Data'
    else:
        data_root = '/data5/chendaiyuan/Data'

    torch.cuda.empty_cache()
    vgg19 = Vgg19().cuda()
    # en2_decoder2 = Style_Reso2_En2_Decoder2(embedding_dim=args.emb_dim, noise_dim=args.noise_dim,
    #                                         hidden_dim=args.hidden_dim).cuda()
    # Generator (spatial-attention variant) and multi-scale discriminator.
    en2_decoder2 = Style_SpatialAttn_Reso2_En2_Decoder2(
        embedding_dim=args.emb_dim,
        noise_dim=args.noise_dim,
        hidden_dim=args.hidden_dim).cuda()
    netD = Discriminator(num_chan=3,
                         hid_dim=args.D_step_dim,
                         sent_dim=1024,
                         emb_dim=args.emb_dim,
                         side_output_at=[64, 128]).cuda()
    print(en2_decoder2)
    print(netD)

    # Scale the batch size by the number of GPUs listed in args.gpus.
    gpus = [a for a in range(len(args.gpus.split(',')))]
    torch.cuda.set_device(gpus[0])
    args.batch_size = args.batch_size * len(gpus)

    import torch.backends.cudnn as cudnn
    # NOTE(review): deterministic=True and benchmark=True together is
    # contradictory — benchmark autotuning picks algorithms non-deterministically.
    # Confirm which of the two is actually intended.
    cudnn.deterministic = True
    cudnn.benchmark = True

    data_name = args.dataset
    datadir = os.path.join(data_root, data_name)
Example #12
0
class ALI(BaseModel):
    """Adversarially Learned Inference (ALI / BiGAN).

    An encoder E: x -> z and a decoder G: z -> x are trained jointly
    against a discriminator over (image, latent) pairs: the discriminator
    is pushed to score real pairs (x, E(x)) as 1 and generated pairs
    (G(z), z) as 0, while the encoder/decoder are trained on the same
    pairs with the labels flipped.
    """

    def __init__(self, opt):
        super(ALI, self).__init__(opt)

        # define input tensors
        self.gpu_ids = opt.gpu_ids
        self.batch_size = opt.batch_size

        self.encoder = VariationalEncoder(gpu_ids=self.gpu_ids,
                                          k=self.opt.z_dimension)
        self.decoder = VariationalDecoder(gpu_ids=self.gpu_ids,
                                          k=self.opt.z_dimension)
        if self.gpu_ids:
            self.encoder.cuda(device=opt.gpu_ids[0])
            self.decoder.cuda(device=opt.gpu_ids[0])
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(),
                                                  lr=self.opt.lr,
                                                  betas=(0.5, 1e-3))
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(),
                                                  lr=self.opt.lr,
                                                  betas=(0.5, 1e-3))
        self.discriminator = Discriminator(gpu_ids=self.gpu_ids)
        if self.gpu_ids:
            self.discriminator.cuda(device=opt.gpu_ids[0])
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.opt.lr, betas=(0.5, 1e-3))

        # normal initialization.
        self.encoder.apply(normal_weight_init)
        self.decoder.apply(normal_weight_init)
        self.discriminator.apply(normal_weight_init)

        # encoder and decoder must agree on the latent dimensionality.
        assert self.decoder.k == self.encoder.k

        # input buffer, reused across batches
        self.input = self.Tensor(opt.batch_size, opt.input_channel, opt.height,
                                 opt.width)
        self.x = None

        self.normal_z = None
        self.sampled_x = None
        self.sampled_z = None
        self.d_sampled_x = None
        self.d_sampled_z = None

        # losses
        self.loss_function = GANLoss(len(self.gpu_ids) > 0)
        self.D_loss = None
        self.G_loss = None

    def set_input(self, data, is_z_given=False):
        """Copy a data batch into the reusable input buffer.

        Unless the caller supplies z separately (``is_z_given``), a fresh
        standard-normal z is sampled as well.
        """
        # Re-materialize the buffer before resizing it to the batch shape.
        temp = self.input.clone()
        temp.resize_(self.input.size())
        temp.copy_(self.input)
        self.input = temp
        self.input.resize_(data.size()).copy_(data)
        if not is_z_given:
            self.set_z()

    def set_z(self, var=None, volatile=False):
        """Set the latent code fed to the decoder.

        Uses *var* when one is supplied; otherwise samples a fresh batch of
        standard-normal latents.  (Bug fix: the condition used to be
        inverted — calling ``set_z()`` with no argument stored None, so the
        decoder was later invoked with ``normal_z=None``, and an explicitly
        passed *var* was discarded in favor of a random sample.)
        """
        if var is not None:
            self.normal_z = var
        else:
            if self.gpu_ids:
                self.normal_z = Variable(torch.randn(
                    (self.opt.batch_size, self.encoder.k)).cuda(),
                                         volatile=volatile)
            else:
                self.normal_z = Variable(torch.randn(
                    (self.opt.batch_size, self.encoder.k)),
                                         volatile=volatile)

    def forward(self, volatile=False):
        """Decode normal_z and encode x; score both pairs with the
        discriminator unless running in volatile (inference) mode."""
        # volatile : no back gradient.
        self.x = Variable(self.input, volatile=volatile)
        # Before call self.decoder, normal_z must be set.
        self.sampled_x = self.decoder(self.normal_z)
        self.sampled_z = self.encoder(self.x)

        if not volatile:
            self.d_sampled_x = self.discriminator(self.x, self.sampled_z)
            self.d_sampled_z = self.discriminator(self.sampled_x,
                                                  self.normal_z)

    def test(self):
        """Inference-only forward pass."""
        self.forward(volatile=True)

    def forward_encoder(self, var):
        """Encode *var* into a latent code."""
        return self.encoder(var)

    def forward_decoder(self, var):
        """Decode latent *var* into an image."""
        return self.decoder(var)

    def optimize_parameters(self):
        """One training step: update the discriminator, then the
        encoder and decoder."""
        self.forward()

        # update discriminator
        self.discriminator_optimizer.zero_grad()
        self.backward_D()
        self.discriminator_optimizer.step()
        # update generator
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.backward_G()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

    def backward_D(self):
        """Discriminator loss: real pairs (x, E(x)) toward 1, generated
        pairs (G(z), z) toward 0."""
        self.D_loss = self.loss_function(
            self.d_sampled_x, 1.) + self.loss_function(self.d_sampled_z, 0.)
        # retain_graph: backward_G reuses the same graph right after.
        self.D_loss.backward(retain_graph=True)

    def backward_G(self):
        """Encoder/decoder loss: the same pairs with flipped labels."""
        self.G_loss = self.loss_function(
            self.d_sampled_x, 0.) + self.loss_function(self.d_sampled_z, 1.)
        self.G_loss.backward(retain_graph=True)

    def get_losses(self):
        """Return the current scalar losses (legacy pre-0.4 tensor
        indexing via .data.numpy()[0])."""
        return OrderedDict([
            ('D_loss', self.D_loss.cpu().data.numpy()[0]),
            ('G_loss', self.G_loss.cpu().data.numpy()[0]),
        ])

    def get_visuals(self, sample_single_image=True):
        """Return real and generated images as displayable arrays."""
        fake_x = tensor2im(self.sampled_x.data,
                           sample_single_image=sample_single_image)
        real_x = tensor2im(self.x.data,
                           sample_single_image=sample_single_image)
        return OrderedDict([('real_x', real_x), ('fake_x', fake_x)])

    def save(self, epoch):
        """Checkpoint all three networks for *epoch*."""
        self.save_network(self.encoder, 'encoder', epoch, self.gpu_ids)
        self.save_network(self.decoder, 'decoder', epoch, self.gpu_ids)
        self.save_network(self.discriminator, 'discriminator', epoch,
                          self.gpu_ids)

    def load(self, epoch):
        """Restore all three networks from the *epoch* checkpoint."""
        self.load_network(self.encoder, 'encoder', epoch)
        self.load_network(self.decoder, 'decoder', epoch)
        self.load_network(self.discriminator, 'discriminator', epoch)

    def remove(self, epoch):
        """Delete the *epoch* checkpoint (epoch 0 is preserved)."""
        if epoch == 0:
            return
        self.remove_checkpoint('encoder', epoch)
        self.remove_checkpoint('decoder', epoch)
        self.remove_checkpoint('discriminator', epoch)
Example #13
0
def main():
    """Train a cartoon-style GAN (CelebA -> cartoon) and save the weights."""
    print(
        '############################### train.py ###############################'
    )

    # Fix all RNG seeds so runs are reproducible.
    seed = 999
    # seed = random.randint(1, 10000) # use if you want new results
    print("Random Seed: ", seed)
    print()
    random.seed(seed)
    torch.manual_seed(seed)

    # Hyper parameters
    workers = 2
    batch_size = 128
    image_size = 128
    nc = 3
    in_ngc, out_ngc = 3, 3
    in_ndc = in_ngc + out_ngc
    out_ndc = 1
    ngf, ndf = 64, 32
    sf = 100  # style factor for generator
    learning_rate = 0.0005
    beta1 = 0.5
    epochs = 100
    gpu = True
    load_saved_model = False

    # Echo the configuration before training starts.
    for line in (
            f'number of workers : {workers}',
            f'batch size : {batch_size}',
            f'image size : {image_size}',
            f'number of channels : {nc}',
            f'generator feature map size : {ngf}',
            f'discriminator feature map size : {ndf}',
            f'style factor : {sf}',
            f'learning rate : {learning_rate}',
            f'beta1 : {beta1}',
            f'epochs: {epochs}',
            f'GPU: {gpu}',
            f'load saved model: {load_saved_model}',
    ):
        print(line)
    print()

    # Use CUDA only when requested *and* actually available.
    use_cuda = bool(gpu and torch.cuda.is_available())

    # Natural-face (source) dataset.
    src_root = '/home/pbuddare/EEE_598/data/CelebA'
    # src_root = '/Users/prasanth/Academics/ASU/FALL_2019/EEE_598_CIU/data/Project/CelebA'
    data_loader_src = prepare_celeba_data(src_root, batch_size, image_size,
                                          workers)

    # Cartoon (target) dataset.
    tgt_root = '/home/pbuddare/EEE_598/data/Cartoon'
    # tgt_root = '/Users/prasanth/Academics/ASU/FALL_2019/EEE_598_CIU/data/Project/Cartoon'
    data_loader_tgt = prepare_cartoon_data(tgt_root, batch_size, image_size,
                                           workers)

    # Preview one batch from each domain.
    show_images(
        next(iter(data_loader_src))[0], (8, 8), 16,
        'Training images (Natural)', 'human_real')
    show_images(
        next(iter(data_loader_tgt))[0], (8, 8), 16,
        'Training images (Cartoon)', 'cartoon_real')

    # Build the networks and move them to the GPU when one is in use.
    generator = Generator(in_ngc, out_ngc, ngf)
    discriminator = Discriminator(in_ndc, ndf)
    if use_cuda:
        generator.cuda()
        discriminator.cuda()
    init_net(generator, gpu_ids=[0])
    init_net(discriminator, gpu_ids=[0])

    # Losses and their optimizers.
    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()
    optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate,
                             betas=(beta1, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate,
                             betas=(beta1, 0.999))

    # Run the full training loop.
    loss_G, loss_D = train_gan(data_loader_src, data_loader_tgt, generator,
                               discriminator, criterion_GAN, criterion_L1,
                               optimizer_d, optimizer_g, sf, batch_size,
                               epochs, use_cuda)

    # Timestamped filenames for the saved weights.
    stamp = str(datetime.datetime.now().time())
    for ch in (':', '.'):
        stamp = stamp.replace(ch, '')
    current_time = stamp + '.pth'
    g_path = './project_G_' + current_time
    d_path = './project_D_' + current_time
    torch.save(generator.state_dict(), g_path)
    torch.save(discriminator.state_dict(), d_path)

    # Generate and display fake images for one test batch.
    test_imgs = next(iter(data_loader_src))[0]
    show_images(test_imgs, (8, 8), 16, 'Testing images (Natural)',
                'human_real_test')
    if use_cuda:
        test_imgs = test_imgs.cuda()
    fake_imgs = generator(test_imgs).detach()
    show_images(fake_imgs.cpu(), (8, 8), 16, 'Fake images (Cartoon)',
                'cartoon_fake')
Example #14
0
class ALI(BaseModel):
    """Adversarially Learned Inference (ALI / BiGAN).

    An encoder E: x -> z and a decoder G: z -> x are trained jointly
    against a discriminator over (image, latent) pairs: real pairs
    (x, E(x)) are pushed toward 1 and generated pairs (G(z), z) toward 0,
    with the labels flipped for the encoder/decoder update.
    """

    def __init__(self, opt):
        super(ALI, self).__init__(opt)

        # define input tensors
        self.gpu_ids = opt.gpu_ids
        self.batch_size = opt.batch_size

        self.encoder = VariationalEncoder(gpu_ids=self.gpu_ids, k=self.opt.z_dimension)
        self.decoder = VariationalDecoder(gpu_ids=self.gpu_ids, k=self.opt.z_dimension)
        if self.gpu_ids:
            self.encoder.cuda(device=opt.gpu_ids[0])
            self.decoder.cuda(device=opt.gpu_ids[0])
        self.encoder_optimizer = torch.optim.Adam(
            self.encoder.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )
        self.decoder_optimizer = torch.optim.Adam(
            self.decoder.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )
        self.discriminator = Discriminator(gpu_ids=self.gpu_ids)
        if self.gpu_ids:
            self.discriminator.cuda(device=opt.gpu_ids[0])
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.opt.lr,
            betas=(0.5, 1e-3)
        )

        # normal initialization.
        self.encoder.apply(normal_weight_init)
        self.decoder.apply(normal_weight_init)
        self.discriminator.apply(normal_weight_init)

        # encoder and decoder must agree on the latent dimensionality.
        assert self.decoder.k == self.encoder.k

        # input buffer, reused across batches
        self.input = self.Tensor(
            opt.batch_size,
            opt.input_channel,
            opt.height,
            opt.width
        )
        self.x = None

        self.normal_z = None
        self.sampled_x = None
        self.sampled_z = None
        self.d_sampled_x = None
        self.d_sampled_z = None

        # losses
        self.loss_function = GANLoss(len(self.gpu_ids) > 0)
        self.D_loss = None
        self.G_loss = None

    def set_input(self, data, is_z_given=False):
        """Copy a batch into the reusable input buffer; sample a fresh z
        unless the caller supplies one (is_z_given)."""
        # Re-materialize the buffer before resizing it to the batch shape.
        temp = self.input.clone()
        temp.resize_(self.input.size())
        temp.copy_(self.input)
        self.input = temp
        self.input.resize_(data.size()).copy_(data)
        if not is_z_given:
            self.set_z()

    def set_z(self, var=None, volatile=False):
        """Use *var* as the latent code when given, else sample N(0, I).

        Bug fix: the condition used to be inverted — ``set_z()`` with no
        argument stored None (so the decoder was later called with
        ``normal_z=None``), and an explicitly passed *var* was discarded
        in favor of a random sample.
        """
        if var is not None:
            self.normal_z = var
        else:
            if self.gpu_ids:
                self.normal_z = Variable(torch.randn((self.opt.batch_size, self.encoder.k)).cuda(), volatile=volatile)
            else:
                self.normal_z = Variable(torch.randn((self.opt.batch_size, self.encoder.k)), volatile=volatile)

    def forward(self, volatile=False):
        """Decode normal_z and encode x; score both pairs with the
        discriminator unless running in volatile (inference) mode."""
        # volatile : no back gradient.
        self.x = Variable(self.input, volatile=volatile)
        # Before call self.decoder, normal_z must be set.
        self.sampled_x = self.decoder(self.normal_z)
        self.sampled_z = self.encoder(self.x)

        if not volatile:
            self.d_sampled_x = self.discriminator(self.x, self.sampled_z)
            self.d_sampled_z = self.discriminator(self.sampled_x, self.normal_z)

    def test(self):
        """Inference-only forward pass."""
        self.forward(volatile=True)

    def forward_encoder(self, var):
        """Encode *var* into a latent code."""
        return self.encoder(var)

    def forward_decoder(self, var):
        """Decode latent *var* into an image."""
        return self.decoder(var)

    def optimize_parameters(self):
        """One training step: update the discriminator, then the
        encoder and decoder."""
        self.forward()

        # update discriminator
        self.discriminator_optimizer.zero_grad()
        self.backward_D()
        self.discriminator_optimizer.step()
        # update generator
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.backward_G()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

    def backward_D(self):
        """Discriminator loss: real pairs toward 1, generated toward 0."""
        self.D_loss = self.loss_function(
            self.d_sampled_x, 1.
        ) + self.loss_function(
            self.d_sampled_z, 0.
        )
        # retain_graph: backward_G reuses the same graph right after.
        self.D_loss.backward(retain_graph=True)

    def backward_G(self):
        """Encoder/decoder loss: the same pairs with flipped labels."""
        self.G_loss = self.loss_function(
            self.d_sampled_x, 0.
        ) + self.loss_function(
            self.d_sampled_z, 1.
        )
        self.G_loss.backward(retain_graph=True)

    def get_losses(self):
        """Return current scalar losses (legacy pre-0.4 tensor indexing)."""
        return OrderedDict([
            ('D_loss', self.D_loss.cpu().data.numpy()[0]),
            ('G_loss', self.G_loss.cpu().data.numpy()[0]),
        ])

    def get_visuals(self, sample_single_image=True):
        """Return real and generated images as displayable arrays."""
        fake_x = tensor2im(self.sampled_x.data, sample_single_image=sample_single_image)
        real_x = tensor2im(self.x.data, sample_single_image=sample_single_image)
        return OrderedDict([('real_x', real_x), ('fake_x', fake_x)])

    def save(self, epoch):
        """Checkpoint all three networks for *epoch*."""
        self.save_network(self.encoder, 'encoder', epoch, self.gpu_ids)
        self.save_network(self.decoder, 'decoder', epoch, self.gpu_ids)
        self.save_network(self.discriminator, 'discriminator', epoch, self.gpu_ids)

    def load(self, epoch):
        """Restore all three networks from the *epoch* checkpoint."""
        self.load_network(self.encoder, 'encoder', epoch)
        self.load_network(self.decoder, 'decoder', epoch)
        self.load_network(self.discriminator, 'discriminator', epoch)

    def remove(self, epoch):
        """Delete the *epoch* checkpoint (epoch 0 is preserved)."""
        if epoch == 0:
            return
        self.remove_checkpoint('encoder', epoch)
        self.remove_checkpoint('decoder', epoch)
        self.remove_checkpoint('discriminator', epoch)
Example #15
0
from torchsummary import summary as torchsummary
from utils import update_lr, sec2time
from utils.data_loader import get_loader
from utils.train_NMD import *
from models.networks import Generator, Discriminator, NMDiscriminator
import time
from tensorboardX import SummaryWriter
import numpy as np

# TensorBoard writer for training curves.
summary = SummaryWriter()

# Train on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
eps = 1e-7  # small epsilon — presumably for numerical stability in later loss terms; confirm at use site
batch_size = 12
# Generator, image discriminator, and normal-map discriminator.
G = Generator().to(device)
D = Discriminator().to(device)
NMD = NMDiscriminator().to(device)

# torchsummary(G, input_size=(3, 112, 112))

# torchsummary(D, input_size=(3, 448, 448))

# torchsummary(NMD, input_size=(3, 448, 448))

# NMD starts from pretrained weights loaded from disk.
NMD.load_state_dict(torch.load('./models/weights/NMD.pth'))

num_epochs = 1000

# Separate learning rates for the generator and discriminator optimizers.
learning_rateG = 1e-4
learning_rateD = 1e-4
Example #16
0
class NormalUnet(TrainingModule):
    """Trains a Unet generator to predict normal maps from input images,
    optionally with a gradient-penalty image discriminator
    (enabled via config.use_image_disc)."""

    def __init__(self, config):
        super(NormalUnet, self).__init__(config)

        # Generator: image -> normal map.
        self.G = Unet(conv_dim=config.g_conv_dim,
                      n_layers=config.g_layers,
                      max_dim=config.max_conv_dim,
                      im_channels=config.img_channels,
                      skip_connections=config.skip_connections,
                      vgg_like=config.vgg_like)
        # Discriminator over 3-channel normal maps (only exercised when
        # config.use_image_disc is set).
        self.D = Discriminator(image_size=config.image_size,
                               im_channels=3,
                               attr_dim=1,
                               conv_dim=config.d_conv_dim,
                               n_layers=config.d_layers,
                               max_dim=config.max_conv_dim,
                               fc_dim=config.d_fc_dim)

        print(self.G)
        if self.config.use_image_disc:
            print(self.D)

        # Look up the dataset-specific loader factory by name
        # ('<dataset>_loader') from the module globals.
        self.data_loader = globals()['{}_loader'.format(self.config.dataset)](
            self.config.data_root, self.config.mode, self.config.attrs,
            self.config.crop_size, self.config.image_size,
            self.config.batch_size, self.config.data_augmentation)

        self.logger.info("NormalUnet ready")

    ################################################################
    ###################### SAVE/lOAD ###############################
    def save_checkpoint(self):
        """Save G (and D when the image discriminator is in use)."""
        self.save_one_model(self.G, self.optimizer_G, 'G')
        if self.config.use_image_disc:
            self.save_one_model(self.D, self.optimizer_D, 'D')

    def load_checkpoint(self):
        """Restore models from config.checkpoint; optimizer state is
        only restored in train mode."""
        if self.config.checkpoint is None:
            return

        self.load_one_model(
            self.G, self.optimizer_G if self.config.mode == 'train' else None,
            'G')
        if (self.config.use_image_disc):
            self.load_one_model(
                self.D,
                self.optimizer_D if self.config.mode == 'train' else None, 'D')

        self.current_iteration = self.config.checkpoint

    ################################################################
    ################### OPTIM UTILITIES ############################

    def setup_all_optimizers(self):
        """Build optimizers and LR schedulers for G and D, restoring any
        checkpoint in between."""
        self.optimizer_G = self.build_optimizer(self.G, self.config.g_lr)
        self.optimizer_D = self.build_optimizer(self.D, self.config.d_lr)
        self.load_checkpoint()  #load checkpoint if needed
        self.lr_scheduler_G = self.build_scheduler(self.optimizer_G)
        self.lr_scheduler_D = self.build_scheduler(
            self.optimizer_D, not (self.config.use_image_disc))

    def step_schedulers(self, scalars):
        """Advance both LR schedulers and record the new rates in *scalars*."""
        self.lr_scheduler_G.step()
        self.lr_scheduler_D.step()
        scalars['lr/g_lr'] = self.lr_scheduler_G.get_lr()[0]
        scalars['lr/d_lr'] = self.lr_scheduler_D.get_lr()[0]

    def eval_mode(self):
        """Switch both networks to eval mode."""
        self.G.eval()
        self.D.eval()

    def training_mode(self):
        """Switch both networks to train mode."""
        self.G.train()
        self.D.train()

    ################################################################
    ##################### EVAL UTILITIES ###########################
    def log_img_reconstruction(self, img, normals, path=None, writer=False):
        """Predict normals for *img*, concatenate input / ground truth /
        prediction side by side, and save the grid to *path* and/or log
        it to the summary writer.

        NOTE(review): normals[:, 3:] is applied as a multiplicative mask
        on the prediction — presumably a validity-mask channel; confirm
        against the data loader.
        """
        img = img.to(self.device)
        normals = normals.to(self.device)
        normals_hat = self.G(img) * normals[:, 3:]

        x_concat = torch.cat((img[:, :3], normals[:, :3], normals_hat), dim=-1)

        image = tvutils.make_grid(denorm(x_concat), nrow=1)
        if writer:
            self.writer.add_image('sample', image, self.current_iteration)
        if path:
            tvutils.save_image(image, path)

    ########################################################################################
    #####################                 TRAINING               ###########################
    def training_step(self, batch):
        """One optimization step: (optionally) n_critic discriminator
        updates with gradient penalty, then one generator update with
        reconstruction (+ adversarial) loss.  Returns a dict of scalars
        for logging."""
        # ================================================================================= #
        #                            1. Preprocess input data                               #
        # ================================================================================= #
        Ia, normals, _, _ = batch

        Ia = Ia.to(self.device)  # input images
        Ia_3ch = Ia[:, :3]
        normals = normals.to(self.device)

        scalars = {}
        # ================================================================================= #
        #                           2. Train the discriminator                              #
        # ================================================================================= #
        if self.config.use_image_disc:
            self.G.eval()
            self.D.train()

            for _ in range(self.config.n_critic):
                # input is the real normal map
                out_disc_real = self.D(normals[:, :3])
                # fake image normals_hat
                normals_hat = self.G(Ia)
                out_disc_fake = self.D(normals_hat.detach())
                #adversarial losses: raise real scores, lower fake scores
                d_loss_adv_real = -torch.mean(out_disc_real)
                d_loss_adv_fake = torch.mean(out_disc_fake)
                # compute loss for gradient penalty on real/fake interpolates
                alpha = torch.rand(Ia.size(0), 1, 1, 1).to(self.device)
                x_hat = (alpha * normals[:, :3].data +
                         (1 - alpha) * normals_hat.data).requires_grad_(True)
                out_disc = self.D(x_hat)
                d_loss_adv_gp = self.config.lambda_gp * self.gradient_penalty(
                    out_disc, x_hat)
                #full GAN loss
                d_loss_adv = d_loss_adv_real + d_loss_adv_fake + d_loss_adv_gp
                d_loss = self.config.lambda_adv * d_loss_adv
                scalars['D/loss_adv'] = d_loss.item()
                scalars['D/loss_real'] = d_loss_adv_real.item()
                scalars['D/loss_fake'] = d_loss_adv_fake.item()
                scalars['D/loss_gp'] = d_loss_adv_gp.item()

                # backward and optimize
                self.optimize(self.optimizer_D, d_loss)
                # summarize
                scalars['D/loss'] = d_loss.item()

        # ================================================================================= #
        #                              3. Train the generator                               #
        # ================================================================================= #
        self.G.train()
        self.D.eval()

        normals_hat = self.G(Ia)
        g_loss_rec = self.config.lambda_G_rec * self.angular_reconstruction_loss(
            normals[:, :], normals_hat)
        g_loss = g_loss_rec
        scalars['G/loss_rec'] = g_loss_rec.item()

        if self.config.use_image_disc:
            # original-to-target domain : normals_hat -> GAN + classif
            # NOTE(review): G(Ia) is recomputed here (second forward pass);
            # confirm this is intentional rather than reusing the one above.
            normals_hat = self.G(Ia)
            out_disc = self.D(normals_hat)
            # GAN loss
            g_loss_adv = -self.config.lambda_adv * torch.mean(out_disc)
            g_loss += g_loss_adv
            scalars['G/loss_adv'] = g_loss_adv.item()

        # backward and optimize
        self.optimize(self.optimizer_G, g_loss)
        # summarize
        scalars['G/loss'] = g_loss.item()

        return scalars

    def validating_step(self, batch):
        """Log one reconstruction sample image for the current iteration."""
        Ia, normals, _, _ = batch
        self.log_img_reconstruction(
            Ia,
            normals,
            os.path.join(self.config.sample_dir,
                         'sample_{}.png'.format(self.current_iteration)),
            writer=True)

    def testing_step(self, batch, batch_id):
        """Write one result image per test batch to the result dir."""
        i, (Ia, normals, _, _) = batch_id, batch
        self.log_img_reconstruction(Ia,
                                    normals,
                                    os.path.join(
                                        self.config.result_dir,
                                        'sample_{}_{}.png'.format(
                                            i + 1, self.config.checkpoint)),
                                    writer=False)
Example #17
0
def test_output_shape():
    """The discriminator maps a (7, 3, 32, 32) batch to shape (7, 1, 1, 1)."""
    batch = torch.ones([7, 3, 32, 32])
    disc = Discriminator(channels=3)
    result = disc(batch)
    assert result.shape == torch.Size([7, 1, 1, 1])
Example #18
0
class Pose_GAN(nn.Module):
  """Pose-conditioned image-generation GAN: a (stacked or baseline)
  generator plus a discriminator, trained with a log-loss GAN objective
  and an L1 reconstruction term."""

  def __init__(self, opt):
    super(Pose_GAN, self).__init__()

    # load generator and discriminator models
    # adding extra layers for larger image size
    if(opt.checkMode == 0):
      nfilters_decoder = (512, 512, 512, 256, 128, 3) if max(opt.image_size) < 256 else (512, 512, 512, 512, 256, 128, 3)
      nfilters_encoder = (64, 128, 256, 512, 512, 512) if max(opt.image_size) < 256 else (64, 128, 256, 512, 512, 512, 512)
    else:
      # checkMode != 0: much smaller networks (quick sanity-check mode)
      nfilters_decoder = (128, 3) if max(opt.image_size) < 256 else (256, 128, 3)
      nfilters_encoder = (64, 128) if max(opt.image_size) < 256 else (64, 128, 256)

    # generator input = RGB image + pose channels (input pose, and also
    # the target pose when use_input_pose is set)
    if (opt.use_input_pose):
      input_nc = 3 + 2*opt.pose_dim
    else:
      input_nc = 3 + opt.pose_dim

    self.num_stacks = opt.num_stacks
    self.batch_size = opt.batch_size
    self.pose_dim = opt.pose_dim
    if(opt.gen_type=='stacked'):
      self.gen = Stacked_Generator(input_nc, opt.num_stacks, opt.pose_dim, nfilters_encoder, nfilters_decoder, use_input_pose=opt.use_input_pose)
    elif(opt.gen_type=='baseline'):
      self.gen = Generator(input_nc, nfilters_encoder, nfilters_decoder, use_input_pose=opt.use_input_pose)
    else:
      raise Exception('Invalid gen_type')
    # discriminator also sees the output image for the target pose
    self.disc = Discriminator(input_nc + 3, use_input_pose=opt.use_input_pose, checkMode=opt.checkMode)
    print('---------- Networks initialized -------------')
    print_network(self.gen)
    print_network(self.disc)
    print('-----------------------------------------------')
    # Setup the optimizers
    lr = opt.learning_rate
    self.disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
    self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))

    # Network weight initialization
    self.gen.cuda()
    self.disc.cuda()
    self.disc.apply(xavier_weights_init)
    self.gen.apply(xavier_weights_init)

    # Setup the loss function for training
    self.ll_loss_criterion = torch.nn.L1Loss()

  # add code for intermediate supervision for the interpolated poses using pretrained pose-estimator
  def gen_update(self, input, target, interpol_pose, opt):
    """Generator step: adversarial -log(D) loss plus weighted L1
    reconstruction against *target*.  Returns the generated image(s),
    intermediate stack outputs, and [total, l1, adversarial] losses."""
    self.gen.zero_grad()

    if(opt['gen_type']=='stacked'):
      outputs_gen = self.gen(input, interpol_pose)
      out_gen = outputs_gen[-1]
    else:
      out_gen = self.gen(input)
      outputs_gen = []

    inp_img, inp_pose, out_pose = pose_utils.get_imgpose(input, opt['use_input_pose'], opt['pose_dim'])

    inp_dis = torch.cat([inp_img, inp_pose, out_gen, out_pose], dim=1)
    out_dis = self.disc(inp_dis)

    # computing adversarial loss
    # NOTE(review): ad_loss is only bound inside the loop; an empty
    # discriminator output would leave it undefined below.
    for it in range(out_dis.shape[0]):
      out = out_dis[it, :]
      # all_ones is a leftover operand of the commented-out BCE variant.
      all_ones = Variable(torch.ones((out.size(0))).cuda())
      if it==0:
        # ad_loss = nn.functional.binary_cross_entropy(out, all_ones)
        ad_loss = -torch.mean(torch.log(out + 1e-7))
      else:
        # ad_loss += nn.functional.binary_cross_entropy(out, all_ones)
        ad_loss += -torch.mean(torch.log(out + 1e-7)
                               )
    ll_loss = self.ll_loss_criterion(out_gen, target)
    # scale the two terms by their configured penalty weights
    ad_loss = ad_loss * opt['gan_penalty_weight'] / self.batch_size
    ll_loss = ll_loss * opt['l1_penalty_weight']
    total_loss = ad_loss + ll_loss
    total_loss.backward()
    self.gen_opt.step()
    self.gen_ll_loss = ll_loss.item()
    self.gen_ad_loss = ad_loss.item()
    self.gen_total_loss = total_loss.item()
    return out_gen, outputs_gen, [self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss ]

  def dis_update(self, input, target, interpol_pose, real_inp, real_target, opt):
    """Discriminator step: push real (input, target) pairs toward 1 and
    generated pairs toward 0 using -log / -log(1-x) losses.  Returns
    [total, real, fake] loss values."""
    self.disc.zero_grad()

    if (opt['gen_type'] == 'stacked'):
      out_gen = self.gen(input, interpol_pose)
      out_gen = out_gen[-1]
    else:
      out_gen = self.gen(input)

    inp_img, inp_pose, out_pose = pose_utils.get_imgpose(input, opt['use_input_pose'], opt['pose_dim'])

    fake_disc_inp = torch.cat([inp_img, inp_pose, out_gen, out_pose], dim=1)
    r_inp_img, r_inp_pose, r_out_pose = pose_utils.get_imgpose(real_inp, opt['use_input_pose'], opt['pose_dim'])
    real_disc_inp = torch.cat([r_inp_img, r_inp_pose, real_target, r_out_pose], dim=1)
    # first batch_size rows of the joint batch are real, the rest fake
    data_dis = torch.cat((real_disc_inp, fake_disc_inp), 0)
    res_dis = self.disc(data_dis)

    for it in range(res_dis.shape[0]):
      out = res_dis[it,:]
      if(it<opt['batch_size']):
        out_true_n = out.size(0)
        # real inputs should be 1
        # all1 = Variable(torch.ones((out_true_n)).cuda())
        if it == 0:
          # ad_true_loss = nn.functional.binary_cross_entropy(out, all1)
          ad_true_loss = -torch.mean(torch.log(out + 1e-7))
        else:
          # ad_true_loss += nn.functional.binary_cross_entropy(out, all1)
          ad_true_loss += -torch.mean(torch.log(out + 1e-7))
      else:
        out_fake_n = out.size(0)
        # fake inputs should be 0, appear after batch_size iters
        # all0 = Variable(torch.zeros((out_fake_n)).cuda())
        if it == opt['batch_size']:
          ad_fake_loss = -torch.mean(torch.log(1- out + 1e-7))
        else:
          ad_fake_loss += -torch.mean(torch.log(1 - out + 1e-7))

    ad_true_loss = ad_true_loss*opt['gan_penalty_weight']/self.batch_size
    ad_fake_loss = ad_fake_loss*opt['gan_penalty_weight']/self.batch_size
    ad_loss = ad_true_loss + ad_fake_loss
    loss = ad_loss
    loss.backward()
    self.disc_opt.step()
    self.dis_total_loss = loss.item()
    self.dis_true_loss = ad_true_loss.item()
    self.dis_fake_loss = ad_fake_loss.item()
    return [self.dis_total_loss , self.dis_true_loss , self.dis_fake_loss ]

  def resume(self, save_dir):
    """Load the latest gen/disc checkpoints found in *save_dir*.

    Returns the epoch parsed from the checkpoint filename, or 1 when no
    checkpoint exists.
    """
    last_model_name = pose_utils.get_model_list(save_dir,"gen")
    if last_model_name is None:
      return 1
    self.gen.load_state_dict(torch.load(last_model_name))
    # epoch number is encoded in the filename suffix, e.g. gen_090.pkl
    epoch = int(last_model_name[-7:-4])
    print('Resume gen from epoch %d' % epoch)
    last_model_name = pose_utils.get_model_list(save_dir, "dis")
    if last_model_name is None:
      return 1
    epoch = int(last_model_name[-7:-4])
    self.disc.load_state_dict(torch.load(last_model_name))
    print('Resume disc from epoch %d' % epoch)
    return epoch

  def save(self, save_dir, epoch):
    """Write gen/disc weights to save_dir as gen_/disc_<epoch>.pkl."""
    gen_filename = os.path.join(save_dir, 'gen_{0:03d}.pkl'.format(epoch))
    disc_filename = os.path.join(save_dir, 'disc_{0:03d}.pkl'.format(epoch))
    torch.save(self.gen.state_dict(), gen_filename)
    torch.save(self.disc.state_dict(), disc_filename)

  def normalize_image(self, x):
    """Return only the first 3 channels of *x* (the image part)."""
    return x[:,0:3,:,:]
Example #19
0
class DeformablePose_GAN(nn.Module):
    """Pose-transfer GAN built around a deformable-skip generator.

    Holds the generator/discriminator pair (both warm-started from
    pretrained checkpoints), their Adam optimizers, and implements one
    optimization step per side (`gen_update` / `dis_update`) plus
    checkpointing helpers.
    """

    def __init__(self, opt):
        """Construct networks, load pretrained weights, set up optimizers.

        Args:
            opt: namespace with image_size, pose_dim, use_input_pose,
                batch_size, num_stacks, gen_type ('stacked'|'baseline'),
                warp_skip, dataset, learning_rate, content_loss_layer,
                nn_loss_area_size.

        Raises:
            Exception: when opt.gen_type is neither 'stacked' nor 'baseline'.
        """
        super(DeformablePose_GAN, self).__init__()

        # Deeper encoder/decoder stacks for images of 256px or larger.
        nfilters_decoder = (512, 512, 512, 256, 128, 3) if max(
            opt.image_size) < 256 else (512, 512, 512, 512, 256, 128, 3)
        nfilters_encoder = (64, 128, 256, 512, 512, 512) if max(
            opt.image_size) < 256 else (64, 128, 256, 512, 512, 512, 512)

        # Generator input channels: RGB image plus one or two pose maps.
        if opt.use_input_pose:
            input_nc = 3 + 2 * opt.pose_dim
        else:
            input_nc = 3 + opt.pose_dim

        self.batch_size = opt.batch_size
        self.num_stacks = opt.num_stacks
        self.pose_dim = opt.pose_dim

        if opt.gen_type == 'stacked':
            self.gen = Stacked_Generator(input_nc,
                                         opt.num_stacks,
                                         opt.image_size,
                                         opt.pose_dim,
                                         nfilters_encoder,
                                         nfilters_decoder,
                                         opt.warp_skip,
                                         use_input_pose=opt.use_input_pose)
            # Warm-start from a baseline run ("hack to get better results").
            pretrained_gen_path = '../exp/' + 'full_' + opt.dataset + '/models/gen_090.pkl'
            self.gen.generator.load_state_dict(torch.load(pretrained_gen_path))
            print("Loaded generator from pretrained model ")
        elif opt.gen_type == 'baseline':
            self.gen = Deformable_Generator(input_nc,
                                            self.pose_dim,
                                            opt.image_size,
                                            nfilters_encoder,
                                            nfilters_decoder,
                                            opt.warp_skip,
                                            use_input_pose=opt.use_input_pose)
        else:
            raise Exception('Invalid gen_type')

        # Discriminator also sees the output image for the target pose.
        self.disc = Discriminator(input_nc + 3,
                                  use_input_pose=opt.use_input_pose)
        # NOTE(review): hard-coded absolute checkpoint path — should come
        # from opt so this runs outside the original author's machine.
        pretrained_disc_path = "/home/linlilang/pose-transfer/exp/baseline_market/models/disc_020.pkl"
        self.disc.load_state_dict(torch.load(pretrained_disc_path))
        # Printed after the load so the message only appears on success.
        print("Loaded discriminator from pretrained model ")

        print('---------- Networks initialized -------------')
        print('-----------------------------------------------')

        # Optimizers; beta1=0.5 as is conventional for GAN training.
        lr = opt.learning_rate
        self.disc_opt = torch.optim.Adam(self.disc.parameters(),
                                         lr=lr,
                                         betas=(0.5, 0.999))
        self.gen_opt = torch.optim.Adam(self.gen.parameters(),
                                        lr=lr,
                                        betas=(0.5, 0.999))

        self.content_loss_layer = opt.content_loss_layer
        self.nn_loss_area_size = opt.nn_loss_area_size
        if self.content_loss_layer != 'none':
            # Feature extractor used by the perceptual (content) loss.
            self.content_model = resnet101(pretrained=True)
        self.gen.cuda()
        self.disc.cuda()
        self._nn_loss_area_size = opt.nn_loss_area_size
        # Weight init skipped deliberately: both nets start from pretrained
        # checkpoints (Keras Defo-GAN used glorot/xavier uniform here).
        self.ll_loss_criterion = torch.nn.L1Loss()

    @staticmethod
    def _adv_real_loss(scores):
        """Sum over rows of -mean(log(score + 1e-7)): pushes scores to 1."""
        return sum(-torch.mean(torch.log(row + 1e-7)) for row in scores)

    @staticmethod
    def _adv_fake_loss(scores):
        """Sum over rows of -mean(log(1 - score + 1e-7)): pushes scores to 0."""
        return sum(-torch.mean(torch.log(1 - row + 1e-7)) for row in scores)

    def gen_update(self, input, target, other_inputs, opt):
        """One generator optimization step.

        Returns:
            (out_gen, outputs_gen, [total_loss, l1_loss, adv_loss]) where
            outputs_gen holds the intermediate stacked outputs (empty for
            the baseline generator).
        """
        self.gen.zero_grad()

        if opt['gen_type'] == 'stacked':
            interpol_pose = other_inputs['interpol_pose']
            interpol_warps = other_inputs['interpol_warps']
            interpol_masks = other_inputs['interpol_masks']
            outputs_gen = self.gen(input, interpol_pose, interpol_warps,
                                   interpol_masks)
            out_gen = outputs_gen[-1]
            # BUGFIX: the stacked generator has no auxiliary heads; the old
            # code used out_gen_2/out_gen_3 unconditionally and crashed here.
            aux_outputs = []
        else:
            out_gen, out_gen_2, out_gen_3 = self.gen(input,
                                                     other_inputs['warps'],
                                                     other_inputs['masks'])
            outputs_gen = []
            aux_outputs = [out_gen_2, out_gen_3]

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])

        # Adversarial term: the generator wants D to score its outputs real.
        ad_loss = self._adv_real_loss(
            self.disc(torch.cat([inp_img, inp_pose, out_gen, out_pose],
                                dim=1)))
        for aux in aux_outputs:
            ad_loss = ad_loss + self._adv_real_loss(
                self.disc(torch.cat([inp_img, inp_pose, aux, out_pose],
                                    dim=1)))

        # Reconstruction term: perceptual nn-loss in feature space when a
        # content layer is configured, plus a plain L1 on every head.
        if self.content_loss_layer != 'none':
            content_out_gen = pose_utils.Feature_Extractor(
                self.content_model,
                input=out_gen,
                layer_name=self.content_loss_layer)
            content_target = pose_utils.Feature_Extractor(
                self.content_model,
                input=target,
                layer_name=self.content_loss_layer)
            ll_loss = self.nn_loss(content_out_gen, content_target,
                                   self.nn_loss_area_size,
                                   self.nn_loss_area_size)
        else:
            # BUGFIX: the L1 term on out_gen was previously counted twice in
            # this branch (assigned here and added again below).
            ll_loss = 0.0
        ll_loss = ll_loss + self.ll_loss_criterion(out_gen, target)
        for aux in aux_outputs:
            ll_loss = ll_loss + self.ll_loss_criterion(aux, target)

        ad_loss = ad_loss * opt['gan_penalty_weight'] / self.batch_size
        ll_loss = ll_loss * opt['l1_penalty_weight']
        total_loss = ad_loss + ll_loss
        total_loss.backward()
        self.gen_opt.step()
        self.gen_ll_loss = ll_loss.item()
        self.gen_ad_loss = ad_loss.item()
        self.gen_total_loss = total_loss.item()
        return out_gen, outputs_gen, [
            self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss
        ]

    def dis_update(self, input, target, other_inputs, real_inp, real_target,
                   opt):
        """One discriminator optimization step.

        Each generator head is scored against the same real batch; rows
        [0, batch_size) of every discriminator pass are real, the rest fake.

        Returns:
            [total_loss, real_loss, fake_loss] as Python floats.
        """
        self.disc.zero_grad()

        if opt['gen_type'] == 'stacked':
            interpol_pose = other_inputs['interpol_pose']
            interpol_warps = other_inputs['interpol_warps']
            interpol_masks = other_inputs['interpol_masks']
            out_gen = self.gen(input, interpol_pose, interpol_warps,
                               interpol_masks)
            out_gen = out_gen[-1]
            # BUGFIX: stacked path produces a single output; the old code
            # referenced undefined out_gen_2/out_gen_3 below and crashed.
            fake_images = [out_gen]
        else:
            out_gen, out_gen_2, out_gen_3 = self.gen(input,
                                                     other_inputs['warps'],
                                                     other_inputs['masks'])
            fake_images = [out_gen, out_gen_2, out_gen_3]

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])
        r_inp_img, r_inp_pose, r_out_pose = pose_utils.get_imgpose(
            real_inp, opt['use_input_pose'], opt['pose_dim'])
        real_disc_inp = torch.cat(
            [r_inp_img, r_inp_pose, real_target, r_out_pose], dim=1)

        bs = opt['batch_size']
        ad_true_loss = 0.0
        ad_fake_loss = 0.0
        for fake in fake_images:
            # detach(): the generator must not receive gradients from the
            # discriminator update (saves a useless backward through G).
            fake_disc_inp = torch.cat(
                [inp_img, inp_pose, fake.detach(), out_pose], dim=1)
            res_dis = self.disc(torch.cat((real_disc_inp, fake_disc_inp), 0))
            # First bs rows come from real data, the remainder from fakes.
            ad_true_loss = ad_true_loss + self._adv_real_loss(res_dis[:bs])
            ad_fake_loss = ad_fake_loss + self._adv_fake_loss(res_dis[bs:])

        ad_true_loss = ad_true_loss * opt[
            'gan_penalty_weight'] / self.batch_size
        ad_fake_loss = ad_fake_loss * opt[
            'gan_penalty_weight'] / self.batch_size
        loss = ad_true_loss + ad_fake_loss
        loss.backward()
        self.disc_opt.step()
        self.dis_total_loss = loss.item()
        self.dis_true_loss = ad_true_loss.item()
        self.dis_fake_loss = ad_fake_loss.item()
        return [self.dis_total_loss, self.dis_true_loss, self.dis_fake_loss]

    def nn_loss(self, predicted, ground_truth, nh=3, nw=3):
        """Neighbourhood L1 loss.

        For every pixel of `predicted`, takes the smallest channel-summed
        absolute difference against any ground-truth pixel inside an
        (nh x nw) window around the same location, then averages over the
        whole batch.
        """
        v_pad = nh // 2
        h_pad = nw // 2
        # Pad with a large negative constant so out-of-image neighbours can
        # never win the min below.
        val_pad = nn.ConstantPad2d((v_pad, v_pad, h_pad, h_pad),
                                   -10000)(ground_truth)

        # Collect every (nh*nw) spatial shift of the padded ground truth.
        reference_tensors = []
        for i_begin in range(0, nh):
            i_end = i_begin - nh + 1
            i_end = None if i_end == 0 else i_end
            for j_begin in range(0, nw):
                j_end = j_begin - nw + 1
                j_end = None if j_end == 0 else j_end
                sub_tensor = val_pad[:, :, i_begin:i_end, j_begin:j_end]
                reference_tensors.append(sub_tensor.unsqueeze(-1))
        # Shape (B, C, H, W, nh*nw).
        reference = torch.cat(reference_tensors, dim=-1)

        predicted = predicted.unsqueeze(-1)
        abs_diff = torch.abs(reference - predicted)
        # Sum along channels, then keep the best (smallest) neighbour.
        norms = torch.sum(abs_diff, dim=1)
        loss, _ = torch.min(norms, dim=-1)
        return torch.mean(loss)

    def resume(self, save_dir):
        """Reload the newest gen/disc checkpoints from save_dir.

        Returns the resumed epoch (parsed from the last three digits of the
        checkpoint filename stem), or 1 when either checkpoint is missing.
        """
        last_model_name = pose_utils.get_model_list(save_dir, "gen")
        if last_model_name is None:
            return 1
        self.gen.load_state_dict(torch.load(last_model_name))
        epoch = int(last_model_name[-7:-4])
        print('Resume gen from epoch %d' % epoch)
        last_model_name = pose_utils.get_model_list(save_dir, "dis")
        if last_model_name is None:
            return 1
        epoch = int(last_model_name[-7:-4])
        self.disc.load_state_dict(torch.load(last_model_name))
        print('Resume disc from epoch %d' % epoch)
        return epoch

    def save(self, save_dir, epoch):
        """Write epoch-stamped gen/disc state dicts into save_dir."""
        gen_filename = os.path.join(save_dir, 'gen_{0:03d}.pkl'.format(epoch))
        disc_filename = os.path.join(save_dir,
                                     'disc_{0:03d}.pkl'.format(epoch))
        torch.save(self.gen.state_dict(), gen_filename)
        torch.save(self.disc.state_dict(), disc_filename)

    def normalize_image(self, x):
        """Keep only the first three (RGB) channels of a NCHW batch."""
        return x[:, 0:3, :, :]
Example #20
0
    def __init__(self, config):
        """Build the FaderNet generator, discriminators, losses and loader."""
        super(FaderNet, self).__init__(config)

        # This variant uses no normalization layers anywhere.
        self.norm = 'none'

        self.G = FaderNetGenerator(
            conv_dim=config.g_conv_dim,
            n_layers=config.g_layers,
            max_dim=config.max_conv_dim,
            im_channels=config.img_channels,
            skip_connections=config.skip_connections,
            vgg_like=config.vgg_like,
            attr_dim=len(config.attrs),
            n_attr_deconv=config.n_attr_deconv,
            normalization=self.norm,
            first_conv=config.first_conv,
            n_bottlenecks=config.n_bottlenecks)

        # All image-discriminator flavours share the same constructor args.
        disc_kwargs = dict(image_size=config.image_size,
                           im_channels=3,
                           attr_dim=len(config.attrs),
                           conv_dim=config.d_conv_dim,
                           n_layers=config.d_layers,
                           max_dim=config.max_conv_dim,
                           fc_dim=config.d_fc_dim,
                           normalization=self.norm)
        gan_style = self.config.GAN_style
        if gan_style == 'vanilla':
            self.D = Discriminator(**disc_kwargs)
        elif gan_style == 'matching':
            self.D = DiscriminatorWithMatchingAttr(**disc_kwargs)
        elif gan_style == 'classif':
            self.D = DiscriminatorWithClassifAttr(**disc_kwargs)

        self.LD = Latent_Discriminator(
            image_size=config.image_size,
            im_channels=config.img_channels,
            attr_dim=len(config.attrs),
            conv_dim=config.g_conv_dim,
            n_layers=config.g_layers,
            max_dim=config.max_conv_dim,
            fc_dim=config.d_fc_dim,
            skip_connections=config.skip_connections,
            vgg_like=config.vgg_like,
            normalization=self.norm,
            first_conv=config.first_conv)

        print(self.G)
        if self.config.use_image_disc:
            print(self.D)
        if self.config.use_latent_disc:
            print(self.LD)

        # Loss modules that the perceptual loss may need during training.
        self.loss_P = PerceptualLoss().to(self.device)
        self.loss_S = StyleLoss().to(self.device)
        self.vgg16_f = VGG16FeatureExtractor(
            ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_4']).to(self.device)
        if self.config.use_image_disc:
            self.criterionGAN = GANLoss(self.config.gan_mode).to(self.device)

        # Resolve the dataset-specific loader factory by name.
        loader_factory = globals()['{}_loader'.format(self.config.dataset)]
        self.data_loader = loader_factory(self.config.data_root,
                                          self.config.mode,
                                          self.config.attrs,
                                          self.config.crop_size,
                                          self.config.image_size,
                                          self.config.batch_size,
                                          self.config.data_augmentation,
                                          mask_input_bg=config.mask_input_bg)

        self.logger.info("FaderNet ready")
Example #21
0
    def __init__(self, opt):
        """Construct the pose-transfer GAN: generator, discriminator,
        pretrained-weight loading and Adam optimizers.

        Args:
            opt: namespace with image_size, pose_dim, use_input_pose,
                batch_size, num_stacks, gen_type ('stacked'|'baseline'),
                warp_skip, dataset, learning_rate, content_loss_layer,
                nn_loss_area_size.

        Raises:
            Exception: when opt.gen_type is neither 'stacked' nor 'baseline'.
        """
        super(DeformablePose_GAN, self).__init__()

        # load generator and discriminator models
        # adding extra layers for larger image size
        nfilters_decoder = (512, 512, 512, 256, 128,
                            3) if max(opt.image_size) < 256 else (512, 512,
                                                                  512, 512,
                                                                  256, 128, 3)
        nfilters_encoder = (64, 128, 256, 512, 512,
                            512) if max(opt.image_size) < 256 else (64, 128,
                                                                    256, 512,
                                                                    512, 512,
                                                                    512)

        # Generator input channels: RGB image plus one or two pose maps.
        if (opt.use_input_pose):
            input_nc = 3 + 2 * opt.pose_dim
        else:
            input_nc = 3 + opt.pose_dim

        self.batch_size = opt.batch_size
        self.num_stacks = opt.num_stacks
        self.pose_dim = opt.pose_dim
        if (opt.gen_type == 'stacked'):
            self.gen = Stacked_Generator(input_nc,
                                         opt.num_stacks,
                                         opt.image_size,
                                         opt.pose_dim,
                                         nfilters_encoder,
                                         nfilters_decoder,
                                         opt.warp_skip,
                                         use_input_pose=opt.use_input_pose)
            # hack to get better results
            # NOTE(review): relative path — assumes a prior 'full_<dataset>'
            # baseline run produced this checkpoint; confirm before running.
            pretrained_gen_path = '../exp/' + 'full_' + opt.dataset + '/models/gen_090.pkl'
            self.gen.generator.load_state_dict(torch.load(pretrained_gen_path))
            print("Loaded generator from pretrained model ")
        elif (opt.gen_type == 'baseline'):
            self.gen = Deformable_Generator(input_nc,
                                            self.pose_dim,
                                            opt.image_size,
                                            nfilters_encoder,
                                            nfilters_decoder,
                                            opt.warp_skip,
                                            use_input_pose=opt.use_input_pose)
        else:
            raise Exception('Invalid gen_type')
        # discriminator also sees the output image for the target pose
        self.disc = Discriminator(input_nc + 3,
                                  use_input_pose=opt.use_input_pose)
        # self.disc_2 = Discriminator(6, use_input_pose=opt.use_input_pose)
        # NOTE(review): hard-coded absolute checkpoint path tied to one
        # machine; should be configurable via opt.
        pretrained_disc_path = "/home/linlilang/pose-transfer/exp/baseline_market/models/disc_020.pkl"
        # NOTE(review): success message printed before load_state_dict runs,
        # so it appears even if the load below fails.
        print("Loaded discriminator from pretrained model ")
        self.disc.load_state_dict(torch.load(pretrained_disc_path))

        print('---------- Networks initialized -------------')
        # print_network(self.gen)
        # print_network(self.disc)
        print('-----------------------------------------------')
        # Setup the optimizers
        lr = opt.learning_rate
        self.disc_opt = torch.optim.Adam(self.disc.parameters(),
                                         lr=lr,
                                         betas=(0.5, 0.999))
        # self.disc_opt_2 = torch.optim.Adam(self.disc_2.parameters(), lr=lr, betas=(0.5, 0.999))
        self.gen_opt = torch.optim.Adam(self.gen.parameters(),
                                        lr=lr,
                                        betas=(0.5, 0.999))

        self.content_loss_layer = opt.content_loss_layer
        self.nn_loss_area_size = opt.nn_loss_area_size
        if self.content_loss_layer != 'none':
            # Feature extractor used by the perceptual (content) loss.
            self.content_model = resnet101(pretrained=True)
            # Setup the loss function for training
        # Network weight initialization
        self.gen.cuda()
        self.disc.cuda()
        # self.disc_2.cuda()
        self._nn_loss_area_size = opt.nn_loss_area_size
        # applying xavier_uniform, equivalent to glorot unigorm, as in Keras Defo GAN
        # skipping as models are pretrained
        # self.disc.apply(xavier_weights_init)
        # self.gen.apply(xavier_weights_init)
        self.ll_loss_criterion = torch.nn.L1Loss()