class CycleGanModel(BaseModel):
    """CycleGAN trainer: generators G_AtoB / G_BtoA and discriminators
    D_A / D_B, optimised with a least-squares GAN loss plus
    cycle-consistency (weight 10.0) and identity (weight 0.5) terms.
    """

    def name(self):
        return 'TrainCycleGanModel'

    def initialize(self, args):
        """Build networks, losses, optimizers and LR schedulers.

        Expects ``args`` to provide: batchSize, init_type, gpu_ids,
        checkpoints_dir, n_epochs, epoch and decay_epoch.
        """
        BaseModel.initialize(self, args)
        # Reusable input buffers; set_input() resizes them to each incoming
        # batch, so 1024x256 is only the initial allocation.
        self.input_A = self.Tensor(args.batchSize, 3, 1024, 256)
        self.input_B = self.Tensor(args.batchSize, 3, 1024, 256)

        # History buffers of previously generated images, used to train the
        # discriminators on a mix of current and past fakes.
        self.fake_A_Buffer = ReplayBuffer()
        self.fake_B_Buffer = ReplayBuffer()

        self.netG_AtoB = networks.define_G(3, 3, 64, 'resnet_9blocks',
                                           'instance', False, args.init_type,
                                           self.gpu_ids)
        self.netG_BtoA = networks.define_G(3, 3, 64, 'resnet_9blocks',
                                           'instance', False, args.init_type,
                                           self.gpu_ids)
        self.netD_A = networks.define_D(3,
                                        64,
                                        'basic',
                                        norm='instance',
                                        use_sigmoid=False,
                                        gpu_ids=args.gpu_ids)
        self.netD_B = networks.define_D(3,
                                        64,
                                        'basic',
                                        norm='instance',
                                        use_sigmoid=False,
                                        gpu_ids=args.gpu_ids)

        self.netG_AtoB.apply(weights_init_normal)
        self.netG_BtoA.apply(weights_init_normal)
        self.netD_A.apply(weights_init_normal)
        self.netD_B.apply(weights_init_normal)

        # Checkpoint locations, kept for the (currently disabled) resume path.
        checkpoint_path_AtoB = os.path.join(args.checkpoints_dir,
                                            'netG_A2B.pth')
        checkpoint_path_BtoA = os.path.join(args.checkpoints_dir,
                                            'netG_B2A.pth')
        checkpoint_path_D_A = os.path.join(args.checkpoints_dir, 'netD_A.pth')
        checkpoint_path_D_B = os.path.join(args.checkpoints_dir, 'netD_B.pth')

        # Resume from checkpoint (disabled; uncomment to continue training).
        # self.netG_AtoB.load_state_dict(torch.load(checkpoint_path_AtoB))
        # self.netG_BtoA.load_state_dict(torch.load(checkpoint_path_BtoA))
        # self.netD_A.load_state_dict(torch.load(checkpoint_path_D_A))
        # self.netD_B.load_state_dict(torch.load(checkpoint_path_D_B))

        # Losses: least-squares GAN objective plus L1 cycle/identity terms.
        # self.criterionGAN = networks.GANLoss().to(self.device)
        self.criterionGAN = torch.nn.MSELoss().cuda()
        self.criterionCycle = torch.nn.L1Loss().cuda()
        self.criterionIdentity = torch.nn.L1Loss().cuda()

        # One optimizer for both generators, one per discriminator.
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_AtoB.parameters(), self.netG_BtoA.parameters()),
                                            lr=0.0001,
                                            betas=(0.5, 0.999))
        self.optimizer_D_a = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=0.0001,
                                              betas=(0.5, 0.999))
        self.optimizer_D_b = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=0.0001,
                                              betas=(0.5, 0.999))

        # Linear LR decay after args.decay_epoch (see LambdaLR helper).
        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=LambdaLR(args.n_epochs, args.epoch,
                               args.decay_epoch).step)
        self.lr_scheduler_D_a = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_a,
            lr_lambda=LambdaLR(args.n_epochs, args.epoch,
                               args.decay_epoch).step)
        self.lr_scheduler_D_b = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_b,
            lr_lambda=LambdaLR(args.n_epochs, args.epoch,
                               args.decay_epoch).step)

    def set_input(self, input_real, input_fake):
        """Copy one unaligned (A, B) batch into the model's input buffers."""
        self.image_real_sizes = input_real['A_sizes']

        input_A = input_real['A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.image_real_paths = input_real['A_paths']

        # Original size of the domain-A image; required by get_image_sizes().
        # Assumes sizes are (width, height) pairs — confirm against the loader.
        self.size_real = (int(self.image_real_sizes[0]),
                          int(self.image_real_sizes[1]))

        self.image_fake_sizes = input_fake['B_sizes']

        input_B = input_fake['B']
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_fake_paths = input_fake['B_paths']

        self.size_fake = (int(self.image_fake_sizes[0]),
                          int(self.image_fake_sizes[1]))

    def train(self):
        """Run one optimisation step on the batch set via set_input()."""
        real_A = Variable(self.input_A)
        real_B = Variable(self.input_B)
        loss_gan = self.criterionGAN
        loss_cycle = self.criterionCycle
        loss_identity = self.criterionIdentity

        # ---------------- Generators ----------------
        self.optimizer_G.zero_grad()

        # Identity loss: feeding a generator its own target domain should
        # leave the image unchanged.
        i_b = self.netG_AtoB(real_B)
        loss_identity_B = loss_identity(i_b, real_B) * 0.5
        i_a = self.netG_BtoA(real_A)
        loss_identity_A = loss_identity(i_a, real_A) * 0.5

        # GAN loss. Targets are built from the discriminator output itself
        # (torch.ones_like), so any input resolution works instead of the
        # previously hard-coded 1x14x62 patch grid.
        fake_B = self.netG_AtoB(real_A)
        pred_fake = self.netD_B(fake_B)
        loss_gan_A2B = loss_gan(pred_fake, torch.ones_like(pred_fake))
        fake_A = self.netG_BtoA(real_B)
        pred_fake = self.netD_A(fake_A)
        loss_gan_B2A = loss_gan(pred_fake, torch.ones_like(pred_fake))

        # Cycle consistency: A -> B -> A and B -> A -> B.
        recovered_a = self.netG_BtoA(fake_B)
        loss_cycle_A = loss_cycle(recovered_a, real_A) * 10.0
        recovered_b = self.netG_AtoB(fake_A)
        loss_cycle_B = loss_cycle(recovered_b, real_B) * 10.0

        loss_G = (loss_identity_A + loss_identity_B + loss_gan_A2B +
                  loss_gan_B2A + loss_cycle_A + loss_cycle_B)
        loss_G.backward()

        self.optimizer_G.step()

        # ---------------- Discriminator A ----------------
        self.optimizer_D_a.zero_grad()

        pred_real = self.netD_A(real_A)
        loss_d_real = loss_gan(pred_real, torch.ones_like(pred_real))
        # Sample from the history buffer; detach so no gradient flows back
        # into the generator.
        fake_A = self.fake_A_Buffer.push_and_pop(fake_A)
        pred_fake = self.netD_A(fake_A.detach())
        loss_d_fake = loss_gan(pred_fake, torch.zeros_like(pred_fake))

        loss_d_a = (loss_d_real + loss_d_fake) * 0.5
        loss_d_a.backward()

        self.optimizer_D_a.step()

        # ---------------- Discriminator B ----------------
        self.optimizer_D_b.zero_grad()

        pred_real = self.netD_B(real_B)
        loss_d_real = loss_gan(pred_real, torch.ones_like(pred_real))
        fake_B = self.fake_B_Buffer.push_and_pop(fake_B)
        pred_fake = self.netD_B(fake_B.detach())
        loss_d_fake = loss_gan(pred_fake, torch.zeros_like(pred_fake))

        loss_d_b = (loss_d_real + loss_d_fake) * 0.5
        loss_d_b.backward()

        self.optimizer_D_b.step()

        # .item() extracts Python floats; formatting a tensor with ':.3f'
        # raises TypeError on modern PyTorch.
        print(
            'Generator Total Loss : {a:.3f},   Generator Identity Loss : {b:.3f},   Generator GAN Loss : {c:.3f},   '
            'Generator Cycle Loss : {d:.3f}'.format(
                a=loss_G.item(),
                b=(loss_identity_A + loss_identity_B).item(),
                c=(loss_gan_A2B + loss_gan_B2A).item(),
                d=(loss_cycle_A + loss_cycle_B).item()))
        print('Discriminator Loss : {a:.3f}'.format(
            a=(loss_d_a + loss_d_b).item()))

    def update_learning_rate(self):
        """Advance all three LR schedulers by one epoch."""
        self.lr_scheduler_G.step()
        self.lr_scheduler_D_a.step()
        self.lr_scheduler_D_b.step()

    def save_checkpoint(self):
        """Persist all four networks under ./checkpoints/."""
        torch.save(self.netG_AtoB.state_dict(), './checkpoints/netG_A2B.pth')
        torch.save(self.netG_BtoA.state_dict(), './checkpoints/netG_B2A.pth')
        torch.save(self.netD_A.state_dict(), './checkpoints/netD_A.pth')
        torch.save(self.netD_B.state_dict(), './checkpoints/netD_B.pth')

    def forward(self):
        """Inference pass: translate the current A batch into domain B."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG_AtoB(self.real_A)

    def get_image_paths(self):
        return self.image_real_paths, self.image_fake_paths

    def get_image_sizes(self):
        # Populated by set_input(); raises AttributeError if called earlier.
        return self.size_real, self.size_fake

    def get_current_visuals(self):
        """Return original/translated images as displayable arrays."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)

        return OrderedDict([('original', real_A), ('restyled', fake_B)])
# --- Example 2 ---
class GanModel(BaseModel):
    """Single-direction GAN trainer: generator G_BtoC maps 3-channel RGB (B)
    to 1-channel depth (C); discriminator D_C judges depth maps. The
    generator objective mixes adversarial (1%) and L1 reconstruction (99%)
    terms.
    """

    def name(self):
        return 'TrainGanModel'

    def initialize(self, args):
        """Build networks, losses, optimizers and LR schedulers.

        Expects ``args`` to provide: batchSize, init_type, gpu_ids,
        checkpoints_dir, n_epochs, epoch and decay_epoch.
        """
        BaseModel.initialize(self, args)
        # Reusable input buffers; set_input() resizes them per batch.
        self.input_B = self.Tensor(args.batchSize, 3, 1024, 256)
        self.input_C = self.Tensor(args.batchSize, 1, 1024, 256)

        # History buffer of generated depth maps for discriminator training.
        self.fake_Buffer = ReplayBuffer()

        self.netG_BtoC = networks.define_G(3, 1, 64, 'unet_128', 'batch',
                                           False, args.init_type, self.gpu_ids)
        self.netD_C = networks.define_D(1,
                                        64,
                                        'basic',
                                        norm='batch',
                                        use_sigmoid=False,
                                        gpu_ids=args.gpu_ids)

        self.netG_BtoC.apply(weights_init_normal)
        self.netD_C.apply(weights_init_normal)

        # Checkpoint locations, kept for the (currently disabled) resume path.
        checkpoint_path_BtoC = os.path.join(args.checkpoints_dir,
                                            'netG_B2C.pth')
        checkpoint_path_D_C = os.path.join(args.checkpoints_dir, 'netD_C.pth')

        # Resume from checkpoint (disabled; uncomment to continue training).
        # self.netG_BtoC.load_state_dict(torch.load(checkpoint_path_BtoC))
        # self.netD_C.load_state_dict(torch.load(checkpoint_path_D_C))

        # Losses. .cuda() added to criterionGAN for consistency with the
        # other criteria (MSELoss holds no parameters, so this is a no-op).
        self.criterionGAN = torch.nn.MSELoss().cuda()
        self.criterionReconstruction = torch.nn.L1Loss().cuda()

        self.optimizer_G = torch.optim.Adam(self.netG_BtoC.parameters(),
                                            lr=0.0002,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD_C.parameters(),
                                            lr=0.0002,
                                            betas=(0.5, 0.999))

        # Linear LR decay after args.decay_epoch (see LambdaLR helper).
        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=LambdaLR(args.n_epochs, args.epoch,
                               args.decay_epoch).step)
        self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=LambdaLR(args.n_epochs, args.epoch,
                               args.decay_epoch).step)

    def set_input(self, input):
        """Copy one paired (RGB, depth) batch into the model's buffers."""
        self.image_syn_sizes = input['B_sizes']

        input_B = input['B']
        # Debug dump of the first sample; writes to disk every batch.
        save_image(input_B[0], './input_check/rgb.jpg')
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_syn_paths = input['B_paths']

        # Original size of the RGB image; required by get_image_sizes().
        # Assumes sizes are (width, height) pairs — confirm against the loader.
        self.size_syn = (int(self.image_syn_sizes[0]),
                         int(self.image_syn_sizes[1]))

        self.image_dep_sizes = input['C_sizes']

        input_C = input['C']
        save_image(input_C[0], './input_check/depth.jpg')
        self.input_C.resize_(input_C.size()).copy_(input_C)
        self.image_dep_paths = input['C_paths']

        self.size_dep = (int(self.image_dep_sizes[0]),
                         int(self.image_dep_sizes[1]))

    def train(self):
        """Run one optimisation step on the batch set via set_input()."""
        syn_data = Variable(self.input_B)
        dep_data = Variable(self.input_C)
        loss_gan = self.criterionGAN
        loss_rec = self.criterionReconstruction

        # ---------------- Generator ----------------
        self.optimizer_G.zero_grad()

        fake_dep = self.netG_BtoC(syn_data)
        loss_r = loss_rec(fake_dep, dep_data)
        # Least-squares GAN target built from the discriminator output shape,
        # so any input resolution works (previously hard-coded 1x14x62).
        pred_fake = self.netD_C(fake_dep)
        loss_g = loss_gan(pred_fake, torch.ones_like(pred_fake))
        # Reconstruction dominates; the adversarial term only sharpens output.
        loss_G = 0.01 * loss_g + 0.99 * loss_r
        loss_G.backward()

        self.optimizer_G.step()

        # ---------------- Discriminator ----------------
        self.optimizer_D.zero_grad()

        pred_real = self.netD_C(dep_data)
        loss_real = loss_gan(pred_real, torch.ones_like(pred_real))
        # Sample from the history buffer; .detach() (missing before) stops
        # gradients from flowing back into the generator, matching the
        # CycleGanModel discriminator updates.
        buffered_fake = self.fake_Buffer.push_and_pop(fake_dep)
        pred_fake = self.netD_C(buffered_fake.detach())
        loss_fake = loss_gan(pred_fake, torch.zeros_like(pred_fake))

        loss_D = (loss_real + loss_fake) * 0.5
        loss_D.backward()

        self.optimizer_D.step()

        # .item() extracts Python floats; formatting a tensor with ':.5f'
        # raises TypeError on modern PyTorch.
        print(
            'Generator Loss : {loss_G:.5f}, Discriminator Loss : {loss_D:.5f}'.
            format(loss_G=loss_G.item(), loss_D=loss_D.item()))

    def update_learning_rate(self):
        """Advance both LR schedulers by one epoch."""
        self.lr_scheduler_G.step()
        self.lr_scheduler_D.step()

    def save_checkpoint(self):
        """Persist both networks under ./checkpoints/."""
        torch.save(self.netG_BtoC.state_dict(), './checkpoints/netG_B2C.pth')
        torch.save(self.netD_C.state_dict(), './checkpoints/netD_C.pth')

    def forward(self):
        """Inference pass: predict depth for the current RGB batch."""
        self.syn_data = Variable(self.input_B)
        self.pred_depth = self.netG_BtoC(self.syn_data)

    def get_image_paths(self):
        return self.image_syn_paths, self.image_dep_paths

    def get_image_sizes(self):
        # Populated by set_input(); raises AttributeError if called earlier.
        return self.size_syn, self.size_dep

    def get_current_visuals(self):
        """Return original/predicted-depth images as displayable arrays."""
        syn_d = util.tensor2im(self.syn_data.data)
        pred_d = util.tensor2im(self.pred_depth.data)

        return OrderedDict([('original', syn_d), ('depth', pred_d)])