Example #1
class CycleGANModel(nn.Module):
    def __init__(self,
                 num_iter=100,
                 num_iter_decay=100,
                 lambda_A=10,
                 lambda_B=10,
                 lambda_identity=0.5):
        super(CycleGANModel, self).__init__()
        self.name = None

        self.epoch_count = torch.tensor(1)  # first training epoch; offsets the LR decay rule below
        self.num_iter = torch.tensor(num_iter)
        self.num_iter_decay = torch.tensor(num_iter_decay)

        self.lambda_A = torch.tensor(lambda_A)
        self.lambda_B = torch.tensor(lambda_B)
        self.lambda_identity = torch.tensor(lambda_identity)

        self.netG_A = define_G(num_res_blocks=9)
        self.netG_B = define_G(num_res_blocks=9)

        self.netD_A = define_D()
        self.netD_B = define_D()

        self.fake_A_pool = ImagePool(pool_size=50)
        self.fake_B_pool = ImagePool(pool_size=50)

        self.criterionGAN = define_GAN_loss()
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        self.optimizer_G_A = optim.Adam(self.netG_A.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))
        self.optimizer_G_B = optim.Adam(self.netG_B.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))
        self.optimizer_D_A = optim.Adam(self.netD_A.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))
        self.optimizer_D_B = optim.Adam(self.netD_B.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))

        lambda_rule = lambda epoch: 1.0 - max(
            0, epoch + self.epoch_count - self.num_iter) / float(
                self.num_iter_decay + 1)

        self.scheduler_G_A = scheduler.LambdaLR(self.optimizer_G_A,
                                                lr_lambda=lambda_rule)
        self.scheduler_G_B = scheduler.LambdaLR(self.optimizer_G_B,
                                                lr_lambda=lambda_rule)
        self.scheduler_D_A = scheduler.LambdaLR(self.optimizer_D_A,
                                                lr_lambda=lambda_rule)
        self.scheduler_D_B = scheduler.LambdaLR(self.optimizer_D_B,
                                                lr_lambda=lambda_rule)

    def set_input(self, batch_A, batch_B):
        self.real_A = batch_A
        self.real_B = batch_B

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def save_images(self, iter_count, batch_size, model_num):
        # model_num selects the results sub-directory for this test run
        path = "./datasets/night2day/test_results/test_results_" + str(
            model_num) + "/"

        for i in range(batch_size):
            img_num = (iter_count) * batch_size + i

            fake_A_numpy = self.fake_A[i].data.cpu().numpy()
            real_A_numpy = self.real_A[i].data.cpu().numpy()
            rec_A_numpy = self.rec_A[i].data.cpu().numpy()
            fake_B_numpy = self.fake_B[i].data.cpu().numpy()
            real_B_numpy = self.real_B[i].data.cpu().numpy()
            rec_B_numpy = self.rec_B[i].data.cpu().numpy()

            image = np.concatenate((fake_A_numpy, real_A_numpy, rec_A_numpy,
                                    fake_B_numpy, real_B_numpy, rec_B_numpy),
                                   2)  # axis 2 = width of a (C, H, W) array: tile images side by side

            save_image(torch.from_numpy(image).squeeze() / 2 + 0.5,
                       path + self.name + "_" + str(img_num) + '.png',
                       nrow=batch_size)

    def backward_D_basic(self, netD, real, fake):
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)

        pred_fake = netD(fake.detach())  # detach: the D update must not backprop into the generator
        loss_D_fake = self.criterionGAN(pred_fake, False)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()

        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.lambda_identity
        lambda_A = self.lambda_A
        lambda_B = self.lambda_B

        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt

            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        self.loss_G = (self.loss_G_A + self.loss_G_B +
                       self.loss_cycle_A + self.loss_cycle_B +
                       self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def set_requires_grad(self, nets, requires_grad=False):
        for net in nets:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def optimize_parameters(self):
        self.forward()

        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G_A.zero_grad()
        self.optimizer_G_B.zero_grad()
        self.backward_G()
        self.optimizer_G_A.step()
        self.optimizer_G_B.step()

        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D_A.zero_grad()
        self.optimizer_D_B.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D_A.step()
        self.optimizer_D_B.step()

        self.loss_D = self.loss_D_A + self.loss_D_B

    def update_learning_rates(self):
        self.scheduler_G_A.step()
        self.scheduler_G_B.step()
        self.scheduler_D_A.step()
        self.scheduler_D_B.step()

    def get_current_losses(self):
        return self.loss_G.item(), self.loss_D.item()
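
Every example on this page feeds its discriminators through an ImagePool, but the class itself is not shown. The following is a minimal sketch of the image history buffer used in CycleGAN/pix2pix-style training (return either the fresh fake or a randomly stored older one), written against plain PyTorch tensors; the actual ImagePool these examples import may differ in details.

import random
import torch

class ImagePool:
    """Buffer of previously generated images.

    query() returns a batch in which each image is, with probability 0.5,
    swapped for a randomly chosen older fake. Training discriminators on a
    history of fakes rather than only the newest ones damps oscillation.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # pool disabled: pass the batch through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)  # still filling the pool
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                old = self.images[idx].clone()
                self.images[idx] = image   # store the new fake...
                out.append(old)            # ...and return an old one
            else:
                out.append(image)          # return the fresh fake
        return torch.cat(out, 0)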
Example #2
class Network(torch.nn.Module):
    def __init__(self,
                 n_input_channels=3,
                 n_output_channels=1,
                 n_blocks=9,
                 initial_filters=64,
                 dropout_value=0.25,
                 lr=1e-3,
                 decay=0,
                 decay_epochs=0,
                 batch_size=1,
                 image_width=640,
                 image_height=640,
                 load_network=False,
                 load_epoch=0,
                 model_path='',
                 name='',
                 gpu_ids=[],
                 gan=False,
                 pool_size=50,
                 lambda_gan=1,
                 n_blocks_discr=3):
        super(Network, self).__init__()
        self.input_nc = n_input_channels
        self.output_nc = n_output_channels
        self.n_blocks = n_blocks
        self.initial_filters = initial_filters
        self.dropout_value = dropout_value
        self.lr = lr
        self.gpu_ids = gpu_ids
        self.batch_size = batch_size
        self.image_width = image_width
        self.image_height = image_height
        self.generator = torch.nn.Module()
        self.discriminator = torch.nn.Module()
        self.decay = decay
        self.decay_epochs = decay_epochs
        self.save_dir = model_path
        os.makedirs(self.save_dir, exist_ok=True)

        self.input_img = None
        self.input_gt = None
        self.var_img = None
        self.var_gt = None
        self.fake_mask = None
        self.dont_care_mask = None

        self.criterion_seg = None
        self.criterion_gan = None
        self.optimizer_seg = None
        self.optimizer_dis = None
        self.fake_mask_pool = None

        self.loss = None
        self.loss_seg = None
        self.loss_g = None
        self.loss_g_gan = None
        self.loss_d_gan = None
        self.gan = gan
        self.pool_size = pool_size
        self.lambda_gan = lambda_gan
        self.n_blocks_discr = n_blocks_discr

        self.load_network = load_network
        self.name = name
        self.load_epoch = load_epoch

        if len(gpu_ids):
            self.tensor = torch.cuda.FloatTensor
        else:
            self.tensor = torch.FloatTensor

        self.initialize(n_input_channels, n_output_channels, n_blocks,
                        initial_filters, dropout_value, lr, batch_size,
                        image_width, image_height, gpu_ids, gan, pool_size,
                        n_blocks_discr)

    def cuda(self):
        self.generator.cuda()
        if self.gan:
            self.discriminator.cuda()

    def initialize(self, n_input_channels, n_output_channels, n_blocks,
                   initial_filters, dropout_value, lr, batch_size, image_width,
                   image_height, gpu_ids, gan, pool_size, n_blocks_discr):

        self.input_img = self.tensor(batch_size, n_input_channels,
                                     image_height, image_width)
        self.input_gt = self.tensor(batch_size, n_output_channels,
                                    image_height, image_width)

        self.generator = uNet(n_input_channels, n_output_channels, n_blocks,
                              initial_filters, dropout_value, gpu_ids)

        if gan:
            self.discriminator = ImageDiscriminatorConv(
                n_output_channels,
                initial_filters,
                dropout_value,
                gpu_ids=gpu_ids,
                n_blocks=n_blocks_discr)
            self.criterion_gan = GANLoss(tensor=self.tensor)
            self.optimizer_dis = torch.optim.Adam(
                self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
            self.fake_mask_pool = ImagePool(pool_size)

        if self.load_network:
            self._load_network(self.generator, 'Model', self.load_epoch)
            if gan:
                self._load_network(self.discriminator, 'Discriminator',
                                   self.load_epoch)

        self.criterion_seg = BinarySelectiveCrossEntropyLoss()
        self.optimizer_seg = torch.optim.Adam(self.generator.parameters(),
                                              lr=lr,
                                              betas=(0.5, 0.999))

        print('---------- Network initialized -------------')
        self.print_network(self.generator)
        if gan:
            self.print_network(self.discriminator)
        print('-----------------------------------------------')

    def set_input(self, input_img, input_gt=None):

        if input_img is not None:
            self.input_img.resize_(input_img.size()).copy_(input_img)

        if input_gt is not None:
            self.input_gt.resize_(input_gt.size()).copy_(input_gt)

    def forward(self, vol=False):
        """
        Function to create autograd variables of inputs (necessary for back-propagation)
        :param vol: True if no backprop is needed
        :return:
        """
        self.var_img = torch.autograd.Variable(self.input_img, volatile=vol)
        self.var_gt = torch.autograd.Variable(self.input_gt, volatile=vol)

    def predict(self):
        """
        Function to predict a segmentation mask from the current input (no backprop)
        :return: fake_mask: the mask generated for the current input image
        """
        assert (self.input_img is not None)

        self.var_img = torch.autograd.Variable(self.input_img, volatile=True)
        self.fake_mask = self.generator.forward(self.var_img)

        return self.fake_mask

    def backward_seg(self):
        self.fake_mask = self.generator.forward(self.var_img)

        self.loss_seg = self.criterion_seg(self.fake_mask, self.var_gt)

        self.loss_g = self.loss_seg

        if self.gan:
            pred_fake = self.discriminator.forward(self.fake_mask)
            self.loss_g_gan = self.criterion_gan(pred_fake, True)
            self.loss_g = self.loss_seg + self.loss_g_gan * self.lambda_gan

        self.loss_g.backward()

    def backward_d(self):
        fake_mask = self.fake_mask_pool.query(self.fake_mask)
        pred_real = self.discriminator.forward(self.var_gt)
        loss_d_real = self.criterion_gan(input_tensor=pred_real,
                                         target_is_real=True)
        pred_fake = self.discriminator.forward(fake_mask.detach())
        loss_d_fake = self.criterion_gan(input_tensor=pred_fake,
                                         target_is_real=False)

        loss_d = (loss_d_real + loss_d_fake) * 0.5
        loss_d.backward()
        self.loss_d_gan = loss_d

    def optimize(self):
        """
        Function for parameter optimization
        :return: None
        """

        self.forward()

        self.optimizer_seg.zero_grad()
        self.backward_seg()
        self.optimizer_seg.step()

        if self.gan:
            self.optimizer_dis.zero_grad()
            self.backward_d()
            self.optimizer_dis.step()

    def get_current_errors(self):
        """
        Function to get access to current errors outside class
        :return: OrderedDict with values different models
        """

        errors = [self.loss_seg.data[0]]
        labels = ["Seg"]

        if self.gan:
            errors.append(self.loss_d_gan.data[0])
            errors.append(self.loss_g_gan.data[0])
            errors.append(self.loss_g.data[0])
            labels.append("Discr")
            labels.append("Seg_GAN")
            labels.append("Seg_total")
        tuple_list = list(zip(labels, errors))

        return OrderedDict(tuple_list)

    def save(self, label):
        """
        Function to save the subnets
        :param label: label (part of the file the subnet will be saved to)
        :return: None
        """
        self._save_network(self.generator, 'Model', label, self.gpu_ids)
        if self.gan:
            self._save_network(self.discriminator, 'Discriminator', label,
                               self.gpu_ids)

    def _save_network(self, network, network_label, epoch_label, gpu_ids):
        """
        Helper Function for saving pytorch networks (can be used in subclasses)
        :param network: the network to save
        :param network_label: the network label (name)
        :param epoch_label: the epoch to save
        :param gpu_ids: the gpu ids to continue training on after saving
        :return: None
        """

        save_filename = str(
            self.name) + '_%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    def _load_network(self, network, network_label, epoch_label):
        """
        Helper Function for loading pytorch networks (can be used in subclasses)
        :param network: the network variable to store the loaded network in
        :param network_label: part of the filename the network should be loaded from
        :param epoch_label: the epoch to load
        :return: None
        """
        save_filename = str(
            self.name) + '_%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        """
        Function for learning rate scheduling
        :return: None
        """
        tmp = self.lr

        if self.decay_epochs:  # guard against the decay_epochs=0 default
            self.lr -= (self.decay / self.decay_epochs)
        for param_group in self.optimizer_seg.param_groups:
            param_group['lr'] = self.lr

        if self.gan:
            for param_group in self.optimizer_dis.param_groups:
                param_group['lr'] = self.lr

        print('update learning rate: %f -> %f' % (tmp, self.lr))

    @staticmethod
    def print_network(network):
        """
        Static Helper Function to print a network summary
        :param network:
        :return: None
        """
        num_params = 0
        for param in network.parameters():
            num_params += param.numel()
        print(network)
        print('Total number of parameters: %d' % num_params)
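
Because the class bundles buffers, criteria, and optimizers, a training loop reduces to set_input/optimize. Below is a minimal sketch of one run, assuming a loader that yields (image, ground-truth) float tensors matching the configured sizes; `loader` itself is not part of the example.

# Hypothetical driver for the Network class above.
net = Network(gan=True, batch_size=4, decay=1e-4, decay_epochs=100)

for epoch in range(100):
    for img, gt in loader:        # `loader` is an assumed data source
        net.set_input(img, gt)    # copies into the pre-allocated buffers
        net.optimize()            # generator step (+ discriminator step)
    print(net.get_current_errors())
    net.update_learning_rate()    # linear decay of the shared lr
    net.save(str(epoch))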
Example #3
class Pix2Pix(nn.Module):
    def __init__(self, opt):
        super(Pix2Pix, self).__init__()
        self.opt = opt
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if use_gpu else torch.Tensor

        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize,
                                   opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                   opt.fineSize)

        # Assuming norm_type = batch
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
        # model  of Generator Net is unet_256
        self.GeneratorNet = Generator(opt.input_nc,
                                      opt.output_nc,
                                      8,
                                      opt.ngf,
                                      norm_layer=norm_layer,
                                      use_dropout=not opt.no_dropout)
        if use_gpu:
            self.GeneratorNet.cuda()
        self.GeneratorNet.apply(init_weights)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # model  of Discriminator Net is basic
            self.DiscriminatorNet = Discriminator(opt.input_nc + opt.output_nc,
                                                  opt.ndf,
                                                  n_layers=3,
                                                  norm_layer=norm_layer,
                                                  use_sigmoid=use_sigmoid)
            if use_gpu:
                self.DiscriminatorNet.cuda()
            self.DiscriminatorNet.apply(init_weights)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.GeneratorNet, 'Generator', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.DiscriminatorNet, 'Discriminator',
                                  opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.learning_rate = opt.lr
            # defining loss functions
            self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan,
                                        tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            self.MySchedulers = []  # initialising schedulers
            self.MyOptimizers = []  # initialising optimizers
            self.generator_optimizer = torch.optim.Adam(
                self.GeneratorNet.parameters(),
                lr=self.learning_rate,
                betas=(opt.beta1, 0.999))
            self.discriminator_optimizer = torch.optim.Adam(
                self.DiscriminatorNet.parameters(),
                lr=self.learning_rate,
                betas=(opt.beta1, 0.999))
            self.MyOptimizers.append(self.generator_optimizer)
            self.MyOptimizers.append(self.discriminator_optimizer)

            def lambda_rule(epoch):
                lr_l = 1.0 - max(
                    0, epoch - opt.niter) / float(opt.niter_decay + 1)
                return lr_l

            for optimizer in self.MyOptimizers:
                self.MySchedulers.append(
                    lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule))
                # assuming opt.lr_policy == 'lambda'

        print('<============ NETWORKS INITIATED ============>')
        print_net(self.GeneratorNet)
        if self.isTrain:
            print_net(self.DiscriminatorNet)
        print('<=============================================>')

    def save_network(self, network, network_label, epoch_label):
        os.makedirs("./saved_models", exist_ok=True)
        save_path = "./saved_models/%s_net_%s.pth" % (epoch_label,
                                                      network_label)
        torch.save(network.cpu().state_dict(), save_path)
        if use_gpu:
            network.cuda()

    def load_network(self, network, network_label, epoch_label):
        save_path = "./saved_models/%s_net_%s.pth" % (epoch_label,
                                                      network_label)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        for scheduler in self.MySchedulers:
            scheduler.step()
        lr = self.MyOptimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def set_input(self, input):
        self.input = input
        if self.opt.which_direction == 'AtoB':
            input_A = input['A']
            input_B = input['B']
            self.image_paths = input['A_paths']
        else:
            input_A = input['B']
            input_B = input['A']
            self.image_paths = input['B_paths']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.generated_B = self.GeneratorNet.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    def get_image_paths(self):
        return self.image_paths

    def backward_Discriminator(self):
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.generated_B), 1))
        self.prediction_fake = self.DiscriminatorNet.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.prediction_fake, False)

        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.prediction_real = self.DiscriminatorNet.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.prediction_real, True)  # real pairs target True

        self.loss_Discriminator = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_Discriminator.backward()

    def backward_Generator(self):

        fake_AB = torch.cat((self.real_A, self.generated_B), 1)
        prediction_fake = self.DiscriminatorNet.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(prediction_fake, True)

        self.loss_G_L1 = self.criterionL1(self.generated_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_Generator = self.loss_G_GAN + self.loss_G_L1
        self.loss_Generator.backward()

    def optimize_parameters(self):
        self.forward()

        self.discriminator_optimizer.zero_grad()
        self.backward_Discriminator()
        self.discriminator_optimizer.step()

        self.generator_optimizer.zero_grad()
        self.backward_Generator()
        self.generator_optimizer.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('D_real', self.loss_D_real.data[0]),
                            ('D_fake', self.loss_D_fake.data[0])])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.generated_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.GeneratorNet, 'Generator', label)
        self.save_network(self.DiscriminatorNet, 'Discriminator', label)
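
Most models on this page (this one, plus examples #1, #2, #7, and #8) call a GANLoss criterion with a boolean target. It is not defined here; a common shape for it, following the pix2pix convention of expanding the boolean into a label tensor and picking MSE (LSGAN) or BCE, is sketched below. The helper these examples actually import may differ.

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Expand a boolean target to a label tensor and apply MSE or BCE.

    use_lsgan=True  -> least-squares GAN loss (MSELoss)
    use_lsgan=False -> vanilla GAN loss (BCEWithLogitsLoss)
    """

    def __init__(self, use_lsgan=True, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer("real_label", torch.tensor(real_label))
        self.register_buffer("fake_label", torch.tensor(fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        label = self.real_label if target_is_real else self.fake_label
        # broadcast the scalar label to the discriminator's output shape
        return self.loss(prediction, label.expand_as(prediction))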
Example #4
def train(dataset, start_epoch, max_epochs, lr_d, lr_g, batch_size, lmda_cyc,
          lmda_idt, pool_size, context):
    mx.random.seed(int(time.time()))

    print("Loading dataset...", flush=True)
    training_set_a = load_dataset(dataset, "trainA")
    training_set_b = load_dataset(dataset, "trainB")

    gen_ab = ResnetGenerator()
    dis_b = PatchDiscriminator()
    gen_ba = ResnetGenerator()
    dis_a = PatchDiscriminator()
    bce_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    l1_loss = mx.gluon.loss.L1Loss()

    gen_ab_params_file = "model/{}.gen_ab.params".format(dataset)
    dis_b_params_file = "model/{}.dis_b.params".format(dataset)
    gen_ab_state_file = "model/{}.gen_ab.state".format(dataset)
    dis_b_state_file = "model/{}.dis_b.state".format(dataset)
    gen_ba_params_file = "model/{}.gen_ba.params".format(dataset)
    dis_a_params_file = "model/{}.dis_a.params".format(dataset)
    gen_ba_state_file = "model/{}.gen_ba.state".format(dataset)
    dis_a_state_file = "model/{}.dis_a.state".format(dataset)

    if os.path.isfile(gen_ab_params_file):
        gen_ab.load_parameters(gen_ab_params_file, ctx=context)
    else:
        gen_ab.initialize(GANInitializer(), ctx=context)

    if os.path.isfile(dis_b_params_file):
        dis_b.load_parameters(dis_b_params_file, ctx=context)
    else:
        dis_b.initialize(GANInitializer(), ctx=context)

    if os.path.isfile(gen_ba_params_file):
        gen_ba.load_parameters(gen_ba_params_file, ctx=context)
    else:
        gen_ba.initialize(GANInitializer(), ctx=context)

    if os.path.isfile(dis_a_params_file):
        dis_a.load_parameters(dis_a_params_file, ctx=context)
    else:
        dis_a.initialize(GANInitializer(), ctx=context)

    print("Learning rate of discriminator:", lr_d, flush=True)
    print("Learning rate of generator:", lr_g, flush=True)
    trainer_gen_ab = mx.gluon.Trainer(gen_ab.collect_params(), "Nadam", {
        "learning_rate": lr_g,
        "beta1": 0.5
    })
    trainer_dis_b = mx.gluon.Trainer(dis_b.collect_params(), "Nadam", {
        "learning_rate": lr_d,
        "beta1": 0.5
    })
    trainer_gen_ba = mx.gluon.Trainer(gen_ba.collect_params(), "Nadam", {
        "learning_rate": lr_g,
        "beta1": 0.5
    })
    trainer_dis_a = mx.gluon.Trainer(dis_a.collect_params(), "Nadam", {
        "learning_rate": lr_d,
        "beta1": 0.5
    })

    if os.path.isfile(gen_ab_state_file):
        trainer_gen_ab.load_states(gen_ab_state_file)

    if os.path.isfile(dis_b_state_file):
        trainer_dis_b.load_states(dis_b_state_file)

    if os.path.isfile(gen_ba_state_file):
        trainer_gen_ba.load_states(gen_ba_state_file)

    if os.path.isfile(dis_a_state_file):
        trainer_dis_a.load_states(dis_a_state_file)

    fake_a_pool = ImagePool(pool_size)
    fake_b_pool = ImagePool(pool_size)

    print("Training...", flush=True)
    for epoch in range(start_epoch, max_epochs):
        ts = time.time()

        random.shuffle(training_set_a)
        random.shuffle(training_set_b)

        training_dis_a_L = 0.0
        training_dis_b_L = 0.0
        training_gen_L = 0.0
        training_batch = 0

        for real_a, real_b in get_batches(training_set_a,
                                          training_set_b,
                                          batch_size,
                                          ctx=context):
            training_batch += 1

            # generated outside autograd.record: the discriminator updates
            # below must not backpropagate into the generators
            fake_a, _ = gen_ba(real_b)
            fake_b, _ = gen_ab(real_a)

            with mx.autograd.record():
                real_a_y, real_a_cam_y = dis_a(real_a)
                real_a_L = bce_loss(real_a_y,
                                    mx.nd.ones_like(real_a_y, ctx=context))
                real_a_cam_L = bce_loss(
                    real_a_cam_y, mx.nd.ones_like(real_a_cam_y, ctx=context))
                fake_a_y, fake_a_cam_y = dis_a(fake_a_pool.query(fake_a))
                fake_a_L = bce_loss(fake_a_y,
                                    mx.nd.zeros_like(fake_a_y, ctx=context))
                fake_a_cam_L = bce_loss(
                    fake_a_cam_y, mx.nd.zeros_like(fake_a_cam_y, ctx=context))
                L = real_a_L + real_a_cam_L + fake_a_L + fake_a_cam_L
                L.backward()
            trainer_dis_a.step(batch_size)
            dis_a_L = mx.nd.mean(L).asscalar()
            if dis_a_L != dis_a_L:  # NaN check: NaN is the only value unequal to itself
                raise ValueError("dis_a loss is NaN")

            with mx.autograd.record():
                real_b_y, real_b_cam_y = dis_b(real_b)
                real_b_L = bce_loss(real_b_y,
                                    mx.nd.ones_like(real_b_y, ctx=context))
                real_b_cam_L = bce_loss(
                    real_b_cam_y, mx.nd.ones_like(real_b_cam_y, ctx=context))
                fake_b_y, fake_b_cam_y = dis_b(fake_b_pool.query(fake_b))
                fake_b_L = bce_loss(fake_b_y,
                                    mx.nd.zeros_like(fake_b_y, ctx=context))
                fake_b_cam_L = bce_loss(
                    fake_b_cam_y, mx.nd.zeros_like(fake_b_cam_y, ctx=context))
                L = real_b_L + real_b_cam_L + fake_b_L + fake_b_cam_L
                L.backward()
            trainer_dis_b.step(batch_size)
            dis_b_L = mx.nd.mean(L).asscalar()
            if dis_b_L != dis_b_L:  # NaN check
                raise ValueError("dis_b loss is NaN")

            with mx.autograd.record():
                fake_a, gen_a_cam_y = gen_ba(real_b)
                fake_a_y, fake_a_cam_y = dis_a(fake_a)
                gan_a_L = bce_loss(fake_a_y,
                                   mx.nd.ones_like(fake_a_y, ctx=context))
                gan_a_cam_L = bce_loss(
                    fake_a_cam_y, mx.nd.ones_like(fake_a_cam_y, ctx=context))
                rec_b, _ = gen_ab(fake_a)
                cyc_b_L = l1_loss(rec_b, real_b)
                idt_a, idt_a_cam_y = gen_ba(real_a)
                idt_a_L = l1_loss(idt_a, real_a)
                gen_a_cam_L = bce_loss(
                    gen_a_cam_y, mx.nd.ones_like(
                        gen_a_cam_y, ctx=context)) + bce_loss(
                            idt_a_cam_y,
                            mx.nd.zeros_like(idt_a_cam_y, ctx=context))
                gen_ba_L = gan_a_L + gan_a_cam_L + cyc_b_L * lmda_cyc + idt_a_L * lmda_cyc * lmda_idt + gen_a_cam_L
                fake_b, gen_b_cam_y = gen_ab(real_a)
                fake_b_y, fake_b_cam_y = dis_b(fake_b)
                gan_b_L = bce_loss(fake_b_y,
                                   mx.nd.ones_like(fake_b_y, ctx=context))
                gan_b_cam_L = bce_loss(
                    fake_b_cam_y, mx.nd.ones_like(fake_b_cam_y, ctx=context))
                rec_a, _ = gen_ba(fake_b)
                cyc_a_L = l1_loss(rec_a, real_a)
                idt_b, idt_b_cam_y = gen_ab(real_b)
                idt_b_L = l1_loss(idt_b, real_b)
                gen_b_cam_L = bce_loss(
                    gen_b_cam_y, mx.nd.ones_like(
                        gen_b_cam_y, ctx=context)) + bce_loss(
                            idt_b_cam_y,
                            mx.nd.zeros_like(idt_b_cam_y, ctx=context))
                gen_ab_L = gan_b_L + gan_b_cam_L + cyc_a_L * lmda_cyc + idt_b_L * lmda_cyc * lmda_idt + gen_b_cam_L
                L = gen_ba_L + gen_ab_L
                L.backward()
            trainer_gen_ba.step(batch_size)
            trainer_gen_ab.step(batch_size)
            gen_L = mx.nd.mean(L).asscalar()
            if gen_L != gen_L:  # NaN check
                raise ValueError("generator loss is NaN")

            training_dis_a_L += dis_a_L
            training_dis_b_L += dis_b_L
            training_gen_L += gen_L
            print(
                "[Epoch %d  Batch %d]  dis_a_loss %.10f  dis_b_loss %.10f  gen_loss %.10f  elapsed %.2fs"
                % (epoch, training_batch, dis_a_L, dis_b_L, gen_L,
                   time.time() - ts),
                flush=True)

        print(
            "[Epoch %d]  training_dis_a_loss %.10f  training_dis_b_loss %.10f  training_gen_loss %.10f  duration %.2fs"
            % (epoch + 1, training_dis_a_L / training_batch,
               training_dis_b_L / training_batch,
               training_gen_L / training_batch, time.time() - ts),
            flush=True)

        gen_ab.save_parameters(gen_ab_params_file)
        gen_ba.save_parameters(gen_ba_params_file)
        dis_a.save_parameters(dis_a_params_file)
        dis_b.save_parameters(dis_b_params_file)
        trainer_gen_ab.save_states(gen_ab_state_file)
        trainer_gen_ba.save_states(gen_ba_state_file)
        trainer_dis_a.save_states(dis_a_state_file)
        trainer_dis_b.save_states(dis_b_state_file)
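
A sketch of how this train function might be invoked. Only mx.gpu/mx.cpu and mx.context.num_gpus are real MXNet calls; the dataset name and hyper-parameter values are illustrative, CycleGAN-style defaults.

import mxnet as mx

context = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()

train(dataset="horse2zebra",        # assumed dataset name for load_dataset
      start_epoch=0, max_epochs=200,
      lr_d=1e-4, lr_g=2e-4, batch_size=1,
      lmda_cyc=10.0, lmda_idt=0.5,  # cycle and identity loss weights
      pool_size=50, context=context)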
Example #5
    def train(self, epochs, batch_size=1, sample_interval=50, pool_size=50):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size, ) + self.disc_patch)
        fake = np.zeros((batch_size, ) + self.disc_patch)

        fake_a_pool = ImagePool(pool_size)
        fake_b_pool = ImagePool(pool_size)

        tensorboard = TensorBoard(batch_size=batch_size, write_grads=True)
        tensorboard.set_model(self.combined)

        def named_logs(model, logs):
            result = {}
            for l in zip(model.metrics_names, logs):
                result[l[0]] = l[1]
            return result

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(
                    self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = fake_b_pool.query(self.g_AB.predict(imgs_A))
                fake_A = fake_a_pool.query(self.g_BA.predict(imgs_B))

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Train the generators
                g_loss = self.combined.train_on_batch(
                    [imgs_A, imgs_B],
                    [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                    % (epoch, epochs,
                       batch_i, self.data_loader.n_batches,
                       d_loss[0], 100 * d_loss[1],
                       g_loss[0],
                       np.mean(g_loss[1:3]),
                       np.mean(g_loss[3:5]),
                       np.mean(g_loss[5:6]),
                       elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

            if epoch % 1 == 0:  # save every epoch; raise the modulus to save less often
                self.combined.save_weights(
                    f"saved_model/{self.dataset_name}/{epoch}.h5")
Example #6
class BEGANModel():
    def __init__(self, opt, gpu_ids=[0], continue_run=None, pool_size=50,
                 beta1=0.5):
        self.opt = opt
        self.kt = 0
        self.lamk = 0.001
        self.lambdaImg = 100
        self.lambdaGan = 1.0
        self.model_names = ['netD', 'netG']
        self.gpu_ids = gpu_ids

        if not continue_run:
            expname = '-'.join([
                'b_' + str(self.opt.batchSize), 'ngf_' + str(self.opt.ngf),
                'ndf_' + str(self.opt.ndf), 'gm_' + str(self.opt.gamma)
            ])
            self.rundir = self.opt.rundir + '/pix2pixBEGAN-' + datetime.now(
            ).strftime('%B%d-%H-%M-%S') + expname + self.opt.comment
            if not os.path.isdir(self.rundir):
                os.mkdir(self.rundir)
            with open(self.rundir + '/options.pkl', 'wb') as file:
                pickle.dump(opt, file)
        else:
            self.rundir = continue_run
            if os.path.isfile(self.rundir + '/options.pkl'):
                with open(self.rundir + '/options.pkl', 'rb') as file:
                    tmp = opt.rundir
                    tmp_lr = opt.lr
                    self.opt = pickle.load(file)
                    self.opt.rundir = tmp
                    self.opt.lr = tmp_lr

        self.netG = UnetGenerator(input_nc=3,
                                  output_nc=3,
                                  num_downs=7,
                                  ngf=self.opt.ngf,
                                  norm_layer=nn.BatchNorm2d,
                                  use_dropout=True)
        self.netD = UnetDescriminator(input_nc=3,
                                      output_nc=3,
                                      num_downs=7,
                                      ngf=self.opt.ndf,
                                      norm_layer=nn.BatchNorm2d,
                                      use_dropout=True)

        # Decide which device we want to run on
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")

        init_net(self.netG, 'normal', 0.002, [0])
        init_net(self.netD, 'normal', 0.002, [0])

        self.netG.to(self.device)
        self.netD.to(self.device)
        self.imagePool = ImagePool(pool_size)

        self.criterionL1 = torch.nn.L1Loss()

        if continue_run:
            self.load_networks('latest')

        self.writer = Logger(self.rundir)
        self.start_step, self.opt.lr = self.writer.get_latest(
            'misc/lr', self.opt.lr)

        # initialize optimizers
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=self.opt.lr,
                                       betas=(beta1, 0.999))
        self.optimD = torch.optim.Adam(self.netD.parameters(),
                                       lr=self.opt.lr,
                                       betas=(beta1, 0.999))

    def set_input(self, data):
        self.real_A = data['A'].to(self.device)
        self.real_B = data['B'].to(self.device)

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimD.zero_grad()
        fake = self.imagePool.query(self.fake_B.detach())

        recon_real_B = self.netD(self.real_B)
        recon_fake = self.netD(fake)

        d_real = torch.mean(torch.abs(recon_real_B - self.real_B))
        d_fake = torch.mean(torch.abs(recon_fake - fake))

        L_D = d_real - self.kt * d_fake
        L_D.backward()
        self.optimD.step()

        self.L_D_val = L_D.item()
        self.d_fake_cpu = d_fake.detach().cpu().item()
        self.d_real_cpu = d_real.detach().cpu().item()
        self.recon_real_B_cpu = recon_real_B.detach().cpu()
        self.recon_fake_cpu = recon_fake.detach().cpu()
        self.fake_cpu = fake.detach().cpu()

    def backward_G(self):
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimG.zero_grad()

        L_Img = self.lambdaImg * self.criterionL1(self.fake_B, self.real_B)
        L_Img.backward(retain_graph=True)

        recon_fake_B = self.netD(self.fake_B)
        self.L_G_fake = self.lambdaGan * torch.mean(
            torch.abs(recon_fake_B - self.fake_B))
        if self.lambdaGan > 0:
            self.L_G_fake.backward()

        self.optimG.step()

        self.L_Img_cpu = L_Img.detach().cpu()
        self.L_G_fake_cpu = self.L_G_fake.detach().cpu()

    def update_K(self):
        balance = self.opt.gamma * self.d_real_cpu - self.d_fake_cpu
        self.kt = min(max(self.kt + self.lamk * balance, 0), 1)
        self.M_global = self.d_real_cpu + np.abs(balance)

    def updatelr(self):
        self.opt.lr = self.opt.lr / 2
        for param_group in self.optimD.param_groups:
            param_group['lr'] = self.opt.lr  # param_group['lr']/2
        for param_group in self.optimG.param_groups:
            param_group['lr'] = self.opt.lr  # param_group['lr']/2

    def log(self, epoch, batchn, n_iter):
        print('Writing summaries....')
        self.writer.scalar_summary('misc/M_global', self.M_global, n_iter)
        self.writer.scalar_summary('misc/kt', self.kt, n_iter)
        self.writer.scalar_summary('misc/lr', self.opt.lr, n_iter)
        self.writer.scalar_summary('loss/L_D', self.L_D_val, n_iter)
        self.writer.scalar_summary('loss/d_real', self.d_real_cpu, n_iter)
        self.writer.scalar_summary('loss/d_fake', self.d_fake_cpu, n_iter)
        self.writer.scalar_summary('loss/L_G', self.L_G_fake_cpu, n_iter)
        self.writer.scalar_summary('loss/L1', self.L_Img_cpu, n_iter)

        test_A = self.test_data['A']
        test_B = self.test_data['B']

        val_A = self.val_data['A']

        with torch.no_grad():
            fake_test_B = self.netG(test_A.to(self.device))
            fake_val_B = self.netG(val_A.to(self.device))

        images = torch.cat([test_A, test_B, fake_test_B.cpu()])
        x = vutils.make_grid(images / 2 + 0.5,
                             normalize=True,
                             scale_each=True,
                             nrow=4)
        self.writer.image_summary('Test/Fixed', [x], n_iter)

        images = torch.cat([
            self.real_A.detach().cpu(),
            self.real_B.cpu(),
            self.fake_B.detach().cpu()
        ])
        x = vutils.make_grid(images / 2 + 0.5,
                             normalize=True,
                             scale_each=True,
                             nrow=4)
        self.writer.image_summary('Test/Last', [x], n_iter)

        images = torch.cat([val_A, fake_val_B.cpu()])
        x = vutils.make_grid(images / 2 + 0.5,
                             normalize=True,
                             scale_each=True,
                             nrow=4)
        self.writer.image_summary('Test/Validation', [x], n_iter)

        images = torch.cat([self.real_B.cpu(), self.recon_real_B_cpu])
        x = vutils.make_grid(images / 2 + 0.5,
                             normalize=True,
                             scale_each=True,
                             nrow=4)
        self.writer.image_summary('Discriminator/Recon_Real', [x], n_iter)

        images = torch.cat([self.fake_cpu, self.recon_fake_cpu])
        x = vutils.make_grid(images / 2 + 0.5,
                             normalize=True,
                             scale_each=True,
                             nrow=4)
        self.writer.image_summary('Discriminator/Recon_Fake', [x], n_iter)

        self.save_networks(epoch)
        for name, param in self.netG.named_parameters():
            if 'bn' in name:
                continue
            self.writer.histo_summary('weight_G/' + name,
                                      param.clone().cpu().data.numpy(), n_iter)
            self.writer.histo_summary('grad_G/' + name,
                                      param.grad.clone().cpu().data.numpy(),
                                      n_iter)

        for name, param in self.netD.named_parameters():
            if 'bn' in name:
                continue
            self.writer.histo_summary('weight_D/' + name,
                                      param.clone().cpu().data.numpy(), n_iter)
            self.writer.histo_summary('grad_D/' + name,
                                      param.grad.clone().cpu().data.numpy(),
                                      n_iter)

    def set_test_input(self, test_data):
        self.test_data = test_data

    def set_val_input(self, val_data):
        self.val_data = val_data

    # save models to the disk
    def save_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '{}_{}.pth'.format(epoch, name)
                save_path = os.path.join(self.rundir, save_filename)
                net = getattr(self, name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)
                self.update_link(
                    save_path,
                    os.path.join(self.rundir, 'latest_{}.pth'.format(name)))

    def update_link(self, src, dst):
        shutil.copy2(src, dst)

    # load models from the disk
    def load_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_%s.pth' % (epoch, name)
                load_path = os.path.join(self.rundir, load_filename)
                if not os.path.isfile(load_path):
                    return

                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                state_dict = torch.load(load_path, map_location=self.device)
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                net.load_state_dict(state_dict)
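
BEGANModel exposes the individual steps but no combined optimize call, so the call order matters: forward, discriminator step, generator step, then the k_t update. A minimal sketch of the loop follows; `model`, `dataloader`, and the logging cadence are assumptions.

# Assumed driver for BEGANModel; set_test_input/set_val_input must have
# been called before log() is reached.
for n_iter, data in enumerate(dataloader):
    model.set_input(data)   # moves real_A / real_B onto the device
    model.forward()         # fake_B = G(real_A)
    model.backward_D()      # BEGAN critic: L_D = d_real - k_t * d_fake
    model.backward_G()      # L1 image loss + reconstruction-based GAN term
    model.update_K()        # proportional control toward the gamma balance
    if n_iter % 100 == 0:
        model.log(epoch=0, batchn=n_iter, n_iter=n_iter)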
Example #7
class GAN(nn.Module):
    def __init__(self,
                 lambda_ABA=settings.lambda_ABA,
                 lambda_BAB=settings.lambda_BAB,
                 lambda_local=settings.lambda_local,
                 pool_size=settings.pool_size,
                 max_crop_side=settings.max_crop_side,
                 decay_start=settings.decay_start,
                 epochs_to_zero_lr=settings.epochs_to_zero_lr,
                 warm_epochs=settings.warmup_epochs):
        super(GAN, self).__init__()

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.r = 0
        self.lambda_ABA = lambda_ABA
        self.lambda_BAB = lambda_BAB
        self.lambda_local = lambda_local
        self.max_crop_side = max_crop_side

        self.netG_A = Generator(input_nc=4, output_nc=3)
        self.netG_B = Generator(input_nc=4, output_nc=3)
        self.netD_A = NLayerDiscriminator(input_nc=3)
        self.netD_B = NLayerDiscriminator(input_nc=3)
        self.localD = NLayerDiscriminator(input_nc=3)
        self.crop_drones = CropDrones()
        self.criterionGAN = GANLoss("lsgan")
        self.criterionCycle = nn.L1Loss()

        init_weights(self.netG_A)
        init_weights(self.netG_B)
        init_weights(self.netD_A)
        init_weights(self.netD_B)
        init_weights(self.localD)

        self.fake_B_pool = ImagePool(pool_size)
        self.fake_A_pool = ImagePool(pool_size)
        self.fake_drones_pool = ImagePool(pool_size)

    def get_inputs(self, input_):
        self.real_A_with_windows = torch.as_tensor(input_['A'],
                                                   device=self.device)
        self.real_B_with_windows = torch.as_tensor(input_['B'],
                                                   device=self.device)
        self.real_A = self.real_A_with_windows[:, :-1]
        self.real_B = self.real_B_with_windows[:, :-1]
        self.A_windows = self.real_A_with_windows[:, -1:]
        self.B_windows = self.real_B_with_windows[:, -1:]
        self.real_drones = torch.zeros(self.real_B.shape[0],
                                       3,
                                       self.max_crop_side,
                                       self.max_crop_side,
                                       device=self.device)
        self.fake_drones = torch.zeros(self.real_A.shape[0],
                                       3,
                                       self.max_crop_side,
                                       self.max_crop_side,
                                       device=self.device)

    def forward(self, input_):
        self.get_inputs(input_)
        self.fake_A = self.netG_A(self.real_B_with_windows)
        self.rest_B = self.netG_B(
            torch.cat([self.fake_A, self.B_windows], dim=1))
        self.real_drones = self.crop_drones(
            (self.real_B_with_windows, self.real_drones))

        self.fake_B = self.netG_B(self.real_A_with_windows)
        self.rest_A = self.netG_A(
            torch.cat([self.fake_B, self.A_windows], dim=1))
        self.fake_drones = self.crop_drones(
            (torch.cat([self.fake_B, self.A_windows],
                       dim=1), self.fake_drones))

    def update_learning_rate(self):
        # self.schedulers must be attached by the training driver
        # (see the sketch after this example)
        for scheduler in self.schedulers:
            scheduler.step()

    def iteration(self, input_):
        self.forward(input_)
        loss_dict = dict()

        # backward for D_A
        real_output_D_A = self.netD_A(self.real_A)
        real_GAN_loss_D_A = self.criterionGAN(real_output_D_A, True)

        fake_A = self.fake_A_pool.query(self.fake_A)
        fake_output_D_A = self.netD_A(fake_A.detach())
        fake_GAN_loss_D_A = self.criterionGAN(fake_output_D_A, False)

        D_A_loss = (real_GAN_loss_D_A + fake_GAN_loss_D_A) * 0.5
        loss_dict['D_A'] = D_A_loss

        # backward for D_B
        real_output_D_B = self.netD_B(self.real_B)
        real_GAN_loss_D_B = self.criterionGAN(real_output_D_B, True)

        fake_B = self.fake_B_pool.query(self.fake_B)
        fake_output_D_B = self.netD_B(fake_B.detach())
        fake_GAN_loss_D_B = self.criterionGAN(fake_output_D_B, False)

        D_B_loss = (real_GAN_loss_D_B + fake_GAN_loss_D_B) * 0.5
        loss_dict['D_B'] = D_B_loss

        # backward for localD
        real_output_localD = self.localD(self.real_drones)
        real_GAN_loss_localD = self.criterionGAN(real_output_localD, True)

        fake_drones = self.fake_drones_pool.query(self.fake_drones)
        fake_output_localD = self.localD(fake_drones.detach())
        fake_GAN_loss_localD = self.criterionGAN(fake_output_localD, False)

        localD_loss = (real_GAN_loss_localD + fake_GAN_loss_localD) * 0.5
        loss_dict['local_D'] = localD_loss

        # backward for G_A and G_B
        G_A_GAN_loss = self.criterionGAN(self.netD_A(self.fake_A), True)
        BAB_cycle_loss = self.criterionCycle(self.real_B, self.rest_B)

        G_B_GAN_loss = self.criterionGAN(self.netD_B(self.fake_B), True)
        G_B_local_loss = self.criterionGAN(self.localD(self.fake_drones), True)
        ABA_cycle_loss = self.criterionCycle(self.real_A, self.rest_A)

        G_loss = (G_B_GAN_loss + G_A_GAN_loss +
                  G_B_local_loss * self.lambda_local +
                  ABA_cycle_loss * self.lambda_ABA * self.r +
                  BAB_cycle_loss * self.lambda_BAB * self.r)

        loss_dict['G_B'] = G_B_GAN_loss
        loss_dict['G_A'] = G_A_GAN_loss
        loss_dict['G_local'] = G_B_local_loss
        loss_dict['G'] = G_loss
        return loss_dict
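
Unlike the other examples, iteration() only assembles a loss dict: optimizers, backward calls, and the self.schedulers consulted by update_learning_rate() all live in the training driver. One possible driver is sketched below; every name outside the class (the optimizers, schedulers, and loader) is an assumption.

import itertools
import torch
from torch.optim import lr_scheduler

model = GAN()
opt_G = torch.optim.Adam(itertools.chain(model.netG_A.parameters(),
                                         model.netG_B.parameters()),
                         lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(itertools.chain(model.netD_A.parameters(),
                                         model.netD_B.parameters(),
                                         model.localD.parameters()),
                         lr=2e-4, betas=(0.5, 0.999))
# attach the schedulers that update_learning_rate() iterates over
model.schedulers = [lr_scheduler.StepLR(opt, step_size=30, gamma=0.5)
                    for opt in (opt_G, opt_D)]

for batch in loader:  # `loader` yields dicts with 'A' and 'B' tensors
    losses = model.iteration(batch)
    opt_D.zero_grad()  # also clears D grads leaked by the previous G step
    (losses['D_A'] + losses['D_B'] + losses['local_D']).backward()
    opt_D.step()
    opt_G.zero_grad()
    losses['G'].backward()
    opt_G.step()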
Example #8
class Pix2PixModel(object):
    def __init__(
        self, name="experiment", phase="train", which_epoch="latest",
        batch_size=1, image_size=128, map_nc=1, input_nc=3, output_nc=3,
        num_downs=7, ngf=64, ndf=64, norm_layer="batch", pool_size=50,
        lr=0.0002, beta1=0.5, lambda_D=0.5, lambda_MSE=10,
        lambda_P=5.0, use_dropout=True, gpu_ids=[], n_layers=3,
        use_sigmoid=False, use_lsgan=True, upsampling="nearest",
        continue_train=False, checkpoints_dir="checkpoints/"
    ):
        # Define input data that will be consumed by networks
        self.input_A = torch.FloatTensor(
            batch_size, 3, image_size, image_size
        )
        self.input_map = torch.FloatTensor(
            batch_size, map_nc, image_size, image_size
        )
        norm_layer = nn.BatchNorm2d \
            if norm_layer == "batch" else nn.InstanceNorm2d

        # Define netD and netG
        self.netG = networks.UnetGenerator(
            input_nc=input_nc, output_nc=map_nc,
            num_downs=num_downs, ngf=ngf,
            use_dropout=use_dropout, gpu_ids=gpu_ids, norm_layer=norm_layer,
            upsampling_layer=upsampling
        )
        self.netD = networks.NLayerDiscriminator(
            input_nc=input_nc + map_nc, ndf=ndf,
            n_layers=n_layers, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids
        )

        # Transfer data to GPU
        if len(gpu_ids) > 0:
            self.input_A = self.input_A.cuda()
            self.input_map = self.input_map.cuda()
            self.netD.cuda()
            self.netG.cuda()

        # Initialize parameters of netD and netG
        self.netG.apply(networks.weights_init)
        self.netD.apply(networks.weights_init)

        # Load trained netD and netG
        if phase == "test" or continue_train:
            netG_checkpoint_file = os.path.join(
                checkpoints_dir, name, "netG_{}.pth".format(which_epoch)
            )
            self.netG.load_state_dict(
                torch.load(netG_checkpoint_file)
            )
            print("Restoring netG from {}".format(netG_checkpoint_file))

        if continue_train:
            netD_checkpoint_file = os.path.join(
                checkpoints_dir, name, "netD_{}.pth".format(which_epoch)
            )
            self.netD.load_state_dict(
                torch.load(netD_checkpoint_file)
            )
            print("Restoring netD from {}".format(netD_checkpoint_file))

        self.name = name
        self.gpu_ids = gpu_ids
        self.checkpoints_dir = checkpoints_dir

        # Criterions
        if phase == "train":
            self.count = 0
            self.lr = lr
            self.lambda_D = lambda_D
            self.lambda_MSE = lambda_MSE

            self.image_pool = ImagePool(pool_size)
            self.criterionGAN = networks.GANLoss(use_lsgan=use_lsgan)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()  # Landmark loss

            self.optimizer_G = torch.optim.Adam(
                self.netG.parameters(), lr=self.lr, betas=(beta1, 0.999)
            )
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=self.lr, betas=(beta1, 0.999)
            )

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG)
            networks.print_network(self.netD)
            print('-----------------------------------------------')

    def set_input(self, input_A, input_map, input_name):
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_map.resize_(input_map.size()).copy_(input_map)
        self.input_name = input_name

    def get_image_paths(self):
        return self.input_name[0]

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_map = self.netG.forward(self.real_A)
        self.real_map = Variable(self.input_map)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_map = self.netG.forward(self.real_A)
        self.real_map = Variable(self.input_map, volatile=True)

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_Amap = self.image_pool.query(
            torch.cat((self.real_A, self.fake_map), 1)
        )
        self.pred_fake = self.netD.forward(fake_Amap.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_Amap = torch.cat((self.real_A, self.real_map), 1)
        self.pred_real = self.netD.forward(real_Amap)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * self.lambda_D

        self.loss_D.backward()

    def backward_G(self):
        # Third, G(A)_map = map
        self.loss_G_MSE = self.criterionMSE(
            self.fake_map, self.real_map
        ) * self.lambda_MSE

        fake_Amap = torch.cat(
            (self.real_A, self.fake_map), 1
        )
        pred_fake = self.netD.forward(fake_Amap)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        self.loss_G = self.loss_G_GAN + self.loss_G_MSE
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        return OrderedDict(
            [
                ('G_GAN', self.loss_G_GAN.data[0]),
                ('G_MSE', self.loss_G_MSE.data[0]),
                ('D_real', self.loss_D_real.data[0]),
                ('D_fake', self.loss_D_fake.data[0])
            ]
        )

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_map = util.tensor2im(self.fake_map.data)
        real_map = util.tensor2im(self.real_map.data)
        return OrderedDict(
            [
                ('real_A', real_A),
                ('fake_map', fake_map),
                ('real_map', real_map)
            ]
        )

    def save(self, which_epoch):
        netD_path = os.path.join(
            self.checkpoints_dir, self.name, "netD_{}.pth".format(which_epoch)
        )
        netG_path = os.path.join(
            self.checkpoints_dir, self.name, "netG_{}.pth".format(which_epoch)
        )
        torch.save(self.netD.cpu().state_dict(), netD_path)
        torch.save(self.netG.cpu().state_dict(), netG_path)

        if len(self.gpu_ids) > 0:
            self.netG.cuda()
            self.netD.cuda()

    def update_learning_rate(self, decay):
        old_lr = self.lr
        self.lr = self.lr * decay
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = self.lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = self.lr
        print('update learning rate: %f -> %f' % (old_lr, self.lr))
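
As with the other pix2pix-style models on this page, training reduces to set_input plus optimize_parameters. A minimal sketch of an epoch loop, assuming a loader that yields (image, landmark_map, name) tuples; the loader and the decay factor are assumptions.

# Hypothetical loop for Pix2PixModel.
model = Pix2PixModel(name="experiment", phase="train")

for epoch in range(1, 201):
    for image, lm_map, name in loader:
        model.set_input(image, lm_map, name)
        model.optimize_parameters()       # D step, then G step
    print(model.get_current_errors())
    model.save(epoch)
    if epoch > 100:
        model.update_learning_rate(0.95)  # multiplicative decay per epoch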