Example #1
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.Resize((args.img_height, args.img_width)),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(
            dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
            batch_size=args.batch_size, shuffle=True, num_workers=4)
        b_loader = torch.utils.data.DataLoader(
            dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
            batch_size=args.batch_size, shuffle=True, num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):
            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # set train
                self.Gab.train()
                self.Gba.train()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Adversarial losses
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                a_cycle_loss = self.L1(a_recon, a_real)
                b_cycle_loss = self.L1(b_recon, b_real)

                # Total generators losses
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss * args.lamda + b_cycle_loss * args.lamda

                # Update generators
                self.Gab.zero_grad()
                self.Gba.zero_grad()
                gen_loss.backward()
                self.gab_optimizer.step()
                self.gba_optimizer.step()

                # Sample from history of generated images
                a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators 
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
                fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = a_dis_real_loss + a_dis_fake_loss
                b_dis_loss = b_dis_real_loss + b_dis_fake_loss

                # Update discriminators
                self.Da.zero_grad()
                self.Db.zero_grad()
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.da_optimizer.step()
                self.db_optimizer.step()

                print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" % 
                                            (epoch, i + 1, min(len(a_loader), len(b_loader)),
                                                            gen_loss,a_dis_loss+b_dis_loss))

            # Overwrite the latest checkpoint
            utils.save_checkpoint({'epoch': epoch + 1,
                                   'Da': self.Da.state_dict(),
                                   'Db': self.Db.state_dict(),
                                   'Gab': self.Gab.state_dict(),
                                   'Gba': self.Gba.state_dict(),
                                   'da_optimizer': self.da_optimizer.state_dict(),
                                   'db_optimizer': self.db_optimizer.state_dict(),
                                   'gab_optimizer': self.gab_optimizer.state_dict(),
                                   'gba_optimizer': self.gba_optimizer.state_dict()},
                                  '%s/latest.ckpt' % (args.checkpoint_dir))
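
All of these snippets lean on helpers from `utils` that are never shown. Below is a minimal sketch of two of them, written from how they are called above: `cuda` moves a tensor or a list of tensors to the GPU, and `Sample_from_Pool` is an image-history buffer in the spirit of the CycleGAN paper (the discriminator sees either the current fake or, with some probability, a previously generated one). The bodies are plausible reconstructions, not the repository's actual code.

import copy
import random

import torch


def cuda(xs):
    # Move a tensor, or a list/tuple of tensors, to the GPU when available.
    if torch.cuda.is_available():
        if isinstance(xs, (list, tuple)):
            return [x.cuda() for x in xs]
        return xs.cuda()
    return xs


class Sample_from_Pool(object):
    # History buffer of up to max_elements generated images. Each call either
    # passes the incoming item through, or (with probability p) returns a
    # stored item and keeps the new one in its place.

    def __init__(self, max_elements=50, p=0.5):
        self.max_elements = max_elements
        self.p = p
        self.items = []

    def __call__(self, in_items):
        out_items = []
        for item in in_items:
            if len(self.items) < self.max_elements:
                self.items.append(copy.deepcopy(item))
                out_items.append(item)
            elif random.random() < self.p:
                idx = random.randrange(len(self.items))
                out_items.append(copy.deepcopy(self.items[idx]))
                self.items[idx] = copy.deepcopy(item)
            else:
                out_items.append(item)
        return out_items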
Example #2
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        max_len = max(len(a_loader), len(b_loader))

        steps = 0
        print_msg = 100  # logging interval; assumed here, since print_msg is not defined in this snippet
        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            a_it = iter(a_loader)
            b_it = iter(b_loader)

            for i in range(max_len):
                try:
                    a_real = next(a_it)[0]
                except StopIteration:
                    # the shorter loader ran out; restart it and fetch again
                    a_it = iter(a_loader)
                    a_real = next(a_it)[0]

                try:
                    b_real = next(b_it)[0]
                except StopIteration:
                    b_it = iter(b_loader)
                    b_real = next(b_it)[0]

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real)
                b_real = Variable(b_real)
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Identity losses
                ###################################################
                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # lamda = 1.75E+12
                lamda = args.lamda * args.idt_coef
                a_idt_loss = self.L1(a_idt, a_real) * lamda
                b_idt_loss = self.L1(b_idt, b_real) * lamda

                # a_real_features = vgg.get_features(a_real)
                # b_real_features = vgg.get_features(b_real)
                # a_fake_features = vgg.get_features(a_fake)
                # b_fake_features = vgg.get_features(b_fake)

                # Content losses
                # content_loss_weight = 1.50
                # content_loss_weight = 1
                # a_content_loss = vgg.get_content_loss(b_fake_features, a_real_features) * content_loss_weight
                # b_content_loss = vgg.get_content_loss(a_fake_features, b_real_features) * content_loss_weight

                # style loss
                # style_loss_weight = 3.00E+05
                # style_loss_weight = 1

                # a_style_loss = vgg.get_style_loss(a_fake_features, a_real_features) * style_loss_weight
                # b_style_loss = vgg.get_style_loss(b_fake_features, b_real_features) * style_loss_weight

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                # gen_loss_weight = 4.50E+08
                gen_loss_weight = 1
                a_gen_loss = self.MSE(a_fake_dis, real_label) * gen_loss_weight
                b_gen_loss = self.MSE(b_fake_dis, real_label) * gen_loss_weight

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda
                # lamda = 3.50E+12
                # a_cycle_loss = self.L1(a_recon, a_real) * lamda
                # b_cycle_loss = self.L1(b_recon, b_real) * lamda

                # Variant including the (disabled) VGG style/content terms:
                # gen_loss = a_gen_loss + b_gen_loss +\
                #            a_cycle_loss + b_cycle_loss +\
                #            a_style_loss + b_style_loss +\
                #            a_content_loss + b_content_loss +\
                #            a_idt_loss + b_idt_loss

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                steps += 1
                if steps % print_msg == 0:
                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, max(len(a_loader), len(b_loader)),
                           gen_loss, a_dis_loss + b_dis_loss))

            # Overwrite the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
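
From Example #2 on, the snippets call a `set_grad` helper to freeze the discriminators during the generator update and unfreeze them afterwards. It is not defined anywhere in these examples; a minimal sketch, assuming it simply toggles `requires_grad` on every parameter:

def set_grad(nets, requires_grad=False):
    # Enable or disable gradient tracking for all parameters of the given nets.
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad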
Example #3
    def train(self, args):
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        a_loader = DataLoader(dsets.ImageFolder(dataset_dirs['trainA'],
                                                transform=transform),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4)
        b_loader = DataLoader(dsets.ImageFolder(dataset_dirs['trainB'],
                                                transform=transform),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = a_real[0]
                b_real = b_real[0]
                a_real, b_real = utils.cuda([a_real, b_real])

                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                a_idt_loss = self.L1(a_idt, a_real) * 5.0
                b_idt_loss = self.L1(b_idt, b_real) * 5.0

                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(torch.ones(a_fake_dis.size()))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                a_cycle_loss = self.L1(a_recon, a_real) * 10.0
                b_cycle_loss = self.L1(b_recon, b_real) * 10.0

                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                gen_loss.backward()
                self.g_optimizer.step()

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                a_fake = torch.Tensor(
                    a_fake_sample([a_fake.cpu().data.numpy()])[0])
                b_fake = torch.Tensor(
                    b_fake_sample([b_fake.cpu().data.numpy()])[0])
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(torch.ones(a_real_dis.size()))
                fake_label = utils.cuda(torch.zeros(a_fake_dis.size()))

                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
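
The `utils.save_checkpoint` call that closes every epoch is also external. A plausible minimal version, assuming it is a thin wrapper over `torch.save` (with a matching loader for resuming, which would populate `self.start_epoch`):

import os

import torch


def save_checkpoint(state, save_path):
    # Persist the training state dict (epoch, model and optimizer state_dicts).
    ckpt_dir = os.path.dirname(save_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(state, save_path)


def load_checkpoint(ckpt_path, map_location=None):
    # Load a checkpoint produced by save_checkpoint.
    return torch.load(ckpt_path, map_location=map_location)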
Example #4
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            # transforms.RandomHorizontalFlip(),
            transforms.Resize((480, 1440)),
            # transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        dataset = torch.utils.data.DataLoader(dsets.ImageFolder(
            '/train_merged/', transform=transform),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)
        # a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
        #                                                 batch_size=args.batch_size, shuffle=True, num_workers=4)
        # b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
        #                                                 batch_size=args.batch_size, shuffle=True, num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, x in enumerate(dataset):
                # step
                step = epoch * len(dataset) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                x = Variable(x[0])
                x = utils.cuda([x])[0]
                shape = x.shape
                # Split the side-by-side merged image into its A (left) and B (right) halves
                a_real, b_real = (x[:, :, :, :shape[3] // 2],
                                  x[:, :, :, shape[3] // 2:])
                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, len(dataset), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Overwrite the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
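
Example #4 trains from a single folder of side-by-side "merged" images and slices each batch into its A and B halves along the width (the original snippet took the same right half twice; that is corrected above). A quick shape check of the split on a dummy batch:

import torch

# Dummy batch shaped like Example #4's input: (N, C, H, W) with A|B side by side.
x = torch.randn(2, 3, 480, 1440)
w = x.shape[3]
a_real, b_real = x[:, :, :, :w // 2], x[:, :, :, w // 2:]
print(a_real.shape, b_real.shape)  # both torch.Size([2, 3, 480, 720])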
Example #5
    def train(self, args):
        # Image transforms
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Initialize dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        # live plot loss
        Gab_history = hl.History()
        Gba_history = hl.History()
        gan_history = hl.History()
        Da_history = hl.History()
        Db_history = hl.History()

        canvas = hl.Canvas()

        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):

                # Identify step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generators ===============================================================
                # Turning off grads for discriminators
                set_grad([self.Da, self.Db], False)

                # Zero out grads of the generator
                self.g_optimizer.zero_grad()

                # Real images from sets A and B
                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Passing through generators.
                # Nomenclature: a_fake is the fake image generated from b_real in domain A.
                # NOTE: Gab generates a from b, and vice versa
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Both generators should be able to reproduce an image from
                # their own domain when given an input from that domain
                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity loss
                a_idt_loss = self.L1(a_idt, a_real) * args.delta
                b_idt_loss = self.L1(b_idt, b_real) * args.delta

                # Adversarial loss
                # Da returns 1 for an image in domain A
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                # Label expected here is 1 to fool the discriminator
                expected_label_a = utils.cuda(
                    Variable(torch.ones(a_fake_dis.size())))
                expected_label_b = utils.cuda(
                    Variable(torch.ones(b_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, expected_label_a)
                b_gen_loss = self.MSE(b_fake_dis, expected_label_b)

                # Cycle Consistency loss
                a_cycle_loss = self.L1(a_recon, a_real) * args.alpha
                b_cycle_loss = self.L1(b_recon, b_real) * args.alpha

                # Structural Cycle Consistency loss
                a_scyc_loss = self.ssim(a_recon, a_real) * args.beta
                b_scyc_loss = self.ssim(b_recon, b_real) * args.beta

                # Structure similarity loss
                # "ba" refers to the SSIM scores between the inputs and outputs of Gba
                # (the grayscale image values are scaled to the 0-1 range)
                gray = kornia.color.RgbToGrayscale()
                a_real_gray = gray((a_real + 1) / 2.0)
                a_fake_gray = gray((a_fake + 1) / 2.0)
                a_recon_gray = gray((a_recon + 1) / 2.0)
                b_real_gray = gray((b_real + 1) / 2.0)
                b_fake_gray = gray((b_fake + 1) / 2.0)
                b_recon_gray = gray((b_recon + 1) / 2.0)

                ba_ssim_loss = (
                    (self.ssim(a_real_gray, b_fake_gray)) +
                    (self.ssim(a_fake_gray, b_recon_gray))) * args.gamma
                ab_ssim_loss = (
                    (self.ssim(b_real_gray, a_fake_gray)) +
                    (self.ssim(b_fake_gray, a_recon_gray))) * args.gamma

                # Total Generator Loss
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_scyc_loss + b_scyc_loss + a_idt_loss + b_idt_loss + ba_ssim_loss + ab_ssim_loss

                # Update Generators
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminators ===========================================================
                # Turn on grads for discriminators
                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from previously generated fake images
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Pass through discriminators
                # Discriminator for domain A
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)

                # Discriminator for domain B
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)

                # Expected label for real image is 1
                exp_real_label_a = utils.cuda(
                    Variable(torch.ones(a_real_dis.size())))
                exp_fake_label_a = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                exp_real_label_b = utils.cuda(
                    Variable(torch.ones(b_real_dis.size())))
                exp_fake_label_b = utils.cuda(
                    Variable(torch.zeros(b_fake_dis.size())))

                # Discriminator losses
                a_real_dis_loss = self.MSE(a_real_dis, exp_real_label_a)
                a_fake_dis_loss = self.MSE(a_fake_dis, exp_fake_label_a)
                b_real_dis_loss = self.MSE(b_real_dis, exp_real_label_b)
                b_fake_dis_loss = self.MSE(b_fake_dis, exp_fake_label_b)

                # Total discriminator loss
                a_dis_loss = (a_fake_dis_loss + a_real_dis_loss) / 2
                b_dis_loss = (b_fake_dis_loss + b_real_dis_loss) / 2

                # Update discriminators
                a_dis_loss.backward()
                b_dis_loss.backward()

                self.d_optimizer.step()

                if i % args.log_freq == 0:
                    # Log losses
                    Gab_history.log(step,
                                    gen_loss=a_gen_loss,
                                    cycle_loss=a_cycle_loss,
                                    idt_loss=a_idt_loss,
                                    ssim_loss=ab_ssim_loss,
                                    scyc_loss=a_scyc_loss)

                    Gba_history.log(step,
                                    gen_loss=b_gen_loss,
                                    cycle_loss=b_cycle_loss,
                                    idt_loss=b_idt_loss,
                                    ssim_loss=ba_ssim_loss,
                                    scyc_loss=b_scyc_loss)

                    Da_history.log(step,
                                   loss=a_dis_loss,
                                   fake_loss=a_fake_dis_loss,
                                   real_loss=a_real_dis_loss)

                    Db_history.log(step,
                                   loss=b_dis_loss,
                                   fake_loss=b_fake_dis_loss,
                                   real_loss=b_real_dis_loss)

                    gan_history.log(step,
                                    gen_loss=gen_loss,
                                    dis_loss=(a_dis_loss + b_dis_loss))

                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, min(len(a_loader), len(b_loader)),
                           gen_loss, a_dis_loss + b_dis_loss))
                    with canvas:
                        canvas.draw_plot([
                            Gba_history['gen_loss'], Gba_history['cycle_loss'],
                            Gba_history['idt_loss'], Gba_history['ssim_loss'],
                            Gba_history['scyc_loss']
                        ],
                                         labels=[
                                             'Adv loss', 'Cycle loss',
                                             'Identity loss', 'SSIM',
                                             'SCyC loss'
                                         ])

                        canvas.draw_plot([
                            Gab_history['gen_loss'], Gab_history['cycle_loss'],
                            Gab_history['idt_loss'], Gab_history['ssim_loss'],
                            Gab_history['scyc_loss']
                        ],
                                         labels=[
                                             'Adv loss', 'Cycle loss',
                                             'Identity loss', 'SSIM',
                                             'SCyC loss'
                                         ])

                        canvas.draw_plot(
                            [
                                Db_history['loss'], Db_history['fake_loss'],
                                Db_history['real_loss']
                            ],
                            labels=['Loss', 'Fake Loss', 'Real Loss'])

                        canvas.draw_plot(
                            [
                                Da_history['loss'], Da_history['fake_loss'],
                                Da_history['real_loss']
                            ],
                            labels=['Loss', 'Fake Loss', 'Real Loss'])

                        canvas.draw_plot(
                            [gan_history['gen_loss'], gan_history['dis_loss']],
                            labels=['Generator loss', 'Discriminator loss'])

            # Overwrite checkpoint
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_path))

            # Save loss history
            history_path = args.results_path + '/loss_history/'
            utils.mkdir([history_path])
            Gab_history.save(history_path + "Gab.pkl")
            Gba_history.save(history_path + "Gba.pkl")
            Da_history.save(history_path + "Da.pkl")
            Db_history.save(history_path + "Db.pkl")
            gan_history.save(history_path + "gan.pkl")

            # Update learning rates
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

            # Run one test cycle
            if args.testing:
                print('Testing')
                tst.test(args, epoch)
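
Example #5's live loss plotting goes through an `hl` alias whose import is not shown; it looks like the hiddenlayer package (`History` for per-step metric logging, `Canvas` for drawing). A minimal standalone usage sketch under that assumption:

import hiddenlayer as hl  # assumed source of the `hl` alias above

history = hl.History()
canvas = hl.Canvas()

for step in range(1, 101):
    history.log(step, loss=1.0 / step)  # record one scalar per step
    if step % 20 == 0:
        with canvas:
            canvas.draw_plot([history["loss"]], labels=["loss"])

history.save("loss_history.pkl")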
Example #6
    def train(self, args):
        # Test input
        transform_test = transforms.Compose([
            transforms.Resize((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_testdata_link(args.dataset_dir)

        a_test_data = dsets.ImageFolder(dataset_dirs['testA'],
                                        transform=transform_test)
        b_test_data = dsets.ImageFolder(dataset_dirs['testB'],
                                        transform=transform_test)

        a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=4)
        b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=4)

        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Overwrite the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))
            # Save sample images for the current epoch
            #######################################################################
            a_real_test = Variable(next(iter(a_test_loader))[0])
            b_real_test = Variable(next(iter(b_test_loader))[0])
            a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

            self.Gab.eval()
            self.Gba.eval()

            with torch.no_grad():
                a_fake_test = self.Gab(b_real_test)
                b_fake_test = self.Gba(a_real_test)
                a_recon_test = self.Gab(b_fake_test)
                b_recon_test = self.Gba(a_fake_test)

            pic = (torch.cat([
                a_real_test, b_fake_test, a_recon_test, b_real_test,
                a_fake_test, b_recon_test
            ],
                             dim=0).data + 1) / 2.0

            if not os.path.isdir(args.results_dir):
                os.makedirs(args.results_dir)

            torchvision.utils.save_image(pic,
                                         args.results_dir +
                                         '/sample_{}.jpg'.format(epoch),
                                         nrow=3)

            self.Gab.train()
            self.Gba.train()
            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
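
Every example from #2 onward ends the epoch with `g_lr_scheduler.step()` and `d_lr_scheduler.step()`. The schedulers are built elsewhere; in CycleGAN training they are conventionally LambdaLR schedulers that hold the learning rate constant for the first half of training and then decay it linearly to zero. A sketch of that convention (the optimizer setup and the epochs/decay_epoch split are assumptions):

import itertools

import torch
import torch.nn as nn

# Stand-in generators; the real ones are ResNet-style CycleGAN generators.
Gab, Gba = nn.Linear(8, 8), nn.Linear(8, 8)

g_optimizer = torch.optim.Adam(itertools.chain(Gab.parameters(),
                                               Gba.parameters()),
                               lr=2e-4, betas=(0.5, 0.999))


def lambda_rule(epoch, epochs=200, decay_epoch=100):
    # Constant LR until decay_epoch, then linear decay to zero at `epochs`.
    return 1.0 - max(0, epoch - decay_epoch) / float(epochs - decay_epoch)


g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(g_optimizer,
                                                   lr_lambda=lambda_rule)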
Example #7
    def train(self, args):
        # data transformation
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Dataloader for class A and B
        a_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        # get fake samples from the sample pool
        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                '''
                Generator First, Discriminator Second
                '''
                # Generator Optimization
                set_grad([self.D_A, self.D_B], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                a_fake = self.G_BtoA(b_real)
                b_fake = self.G_AtoB(a_real)

                a_recon = self.G_BtoA(b_fake)
                b_recon = self.G_AtoB(a_fake)

                a_idt = self.G_BtoA(a_real)
                b_idt = self.G_AtoB(b_real)

                # Identity losses
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                a_fake_dis = self.D_A(a_fake)
                b_fake_dis = self.D_B(b_fake)

                a_real_label = utils.cuda(
                    Variable(torch.ones(a_fake_dis.size())))
                b_real_label = utils.cuda(
                    Variable(torch.ones(b_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, a_real_label)
                b_gen_loss = self.MSE(b_fake_dis, b_real_label)

                # Cycle consistency losses
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Optimization
                set_grad([self.D_A, self.D_B], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                a_real_dis = self.D_A(a_real)
                a_fake_dis = self.D_A(a_fake)
                b_real_dis = self.D_B(b_real)
                b_fake_dis = self.D_B(b_fake)

                a_real_label = utils.cuda(
                    Variable(torch.ones(a_real_dis.size())))
                a_fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))
                b_real_label = utils.cuda(
                    Variable(torch.ones(b_real_dis.size())))
                b_fake_label = utils.cuda(
                    Variable(torch.zeros(b_fake_dis.size())))

                # Discriminator losses
                a_dis_real_loss = self.MSE(a_real_dis, a_real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, a_fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, b_real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, b_fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                # print some information
                if (i + 1) % 20 == 0:
                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, min(len(a_loader), len(b_loader)),
                           gen_loss, a_dis_loss + b_dis_loss))

            # Update the checkpoint
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'D_A': self.D_A.state_dict(),
                    'D_B': self.D_B.state_dict(),
                    'G_AtoB': self.G_AtoB.state_dict(),
                    'G_BtoA': self.G_BtoA.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
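
All of these `train` methods assume a trainer object whose constructor has already built the networks, the LSGAN criterion (`self.MSE`), the L1 criterion for the cycle and identity terms, and joint optimizers over both generators and both discriminators. A hedged sketch of that shared scaffolding, with names mirroring Example #7 (everything here is inferred from the calls above, not taken from the repository):

import itertools

import torch
import torch.nn as nn


class CycleGANTrainer:
    def __init__(self, G_AtoB, G_BtoA, D_A, D_B, lr=2e-4):
        self.G_AtoB, self.G_BtoA = G_AtoB, G_BtoA
        self.D_A, self.D_B = D_A, D_B
        # LSGAN replaces the BCE adversarial criterion with MSE;
        # cycle-consistency and identity terms use L1.
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        self.g_optimizer = torch.optim.Adam(
            itertools.chain(G_AtoB.parameters(), G_BtoA.parameters()),
            lr=lr, betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(
            itertools.chain(D_A.parameters(), D_B.parameters()),
            lr=lr, betas=(0.5, 0.999))
        self.start_epoch = 0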
Example #8
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # transforms.Resize((args.load_height,args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)
        dataset_a = ListDataSet(
            '/media/l/新加卷1/city/data/river/train_256_9w.lst',
            transform=transform)
        dataset_b = ListDataSet('/media/l/新加卷/city/jinan_z3.lst',
                                transform=transform)
        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dataset_a,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dataset_b,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.4f | Dis Loss:%.4f" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict(),
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
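
Neither example defines utils.Sample_from_Pool. A minimal sketch of the replay
buffer its call sites imply (it takes a list of numpy arrays and returns a list
of the same length; the pool size of 50 is an assumption, taken from the
CycleGAN paper):

import copy
import random


class Sample_from_Pool(object):
    """Replay buffer: returns either the incoming item or a stored older one."""

    def __init__(self, max_elements=50):  # 50 is an assumed default
        self.max_elements = max_elements
        self.cur_elements = 0
        self.items = []

    def __call__(self, in_items):
        return_items = []
        for in_item in in_items:
            if self.cur_elements < self.max_elements:
                # Pool not full yet: store the new item and pass it through
                self.items.append(in_item)
                self.cur_elements += 1
                return_items.append(in_item)
            elif random.random() > 0.5:
                # With probability 0.5, swap the new item for a stored one
                idx = random.randrange(self.max_elements)
                old_item = copy.copy(self.items[idx])
                self.items[idx] = in_item
                return_items.append(old_item)
            else:
                return_items.append(in_item)
        return return_items

With this buffer, a_fake_sample([arr])[0] returns an array of the same shape
as arr, which matches how both training loops consume it.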
Example #9
0
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        test_transform = transforms.Compose([
            transforms.Resize((args.test_crop_height, args.test_crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)
        testset_dirs = utils.get_testdata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_test_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            testset_dirs['testA'], transform=test_transform),
                                                    batch_size=1,
                                                    shuffle=False,
                                                    num_workers=4)
        b_test_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            testset_dirs['testB'], transform=test_transform),
                                                    batch_size=1,
                                                    shuffle=False,
                                                    num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            if epoch >= 1:
                print('generating test result...')
                self.save_sample_image(args.test_length, a_test_loader,
                                       b_test_loader, args.results_dir, epoch)
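            # Qualitative test samples are written out at the start of every
            # epoch after the first, so results_dir records progress over time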

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)
            running_Gen_loss = 0
            running_Dis_loss = 0
            ##################################################
            # BEGIN TRAINING FOR ONE EPOCH
            ##################################################
            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                ##################################################
                # Part 1: Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.identity_criteron(
                    a_idt, a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.identity_criteron(
                    b_idt, b_real) * args.lamda * args.idt_coef
                # (the identity terms can be disabled by setting
                # args.idt_coef to 0)

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.adversarial_criteron(a_fake_dis, real_label)
                b_gen_loss = self.adversarial_criteron(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.cycle_consistency_criteron(
                    a_recon, a_real) * args.lamda
                b_cycle_loss = self.cycle_consistency_criteron(
                    b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = (a_gen_loss + b_gen_loss + a_cycle_loss +
                            b_cycle_loss + a_idt_loss + b_idt_loss)

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                ##################################################
                # Part 2: Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.adversarial_criteron(
                    a_real_dis, real_label)
                a_dis_fake_loss = self.adversarial_criteron(
                    a_fake_dis, fake_label)
                b_dis_real_loss = self.adversarial_criteron(
                    b_real_dis, real_label)
                b_dis_fake_loss = self.adversarial_criteron(
                    b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))
                # Accumulate plain Python floats so each batch's autograd
                # graph is not kept alive for the rest of the epoch
                running_Gen_loss += gen_loss.item()
                running_Dis_loss += (a_dis_loss + b_dis_loss).item()
            ##################################################
            # END TRAINING FOR ONE EPOCH
            ##################################################
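            # Log epoch-averaged losses and the current learning rates;
            # self.writer is assumed to be a tensorboard SummaryWriter
            # created in the class constructor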
            self.writer.add_scalar(
                'Gen Loss',
                running_Gen_loss / min(len(a_loader), len(b_loader)), epoch)
            self.writer.add_scalar(
                'Dis Loss',
                running_Dis_loss / min(len(a_loader), len(b_loader)), epoch)
            # Read the LRs from the optimizers (scheduler.get_lr() is
            # deprecated and can report misleading values between steps)
            self.writer.add_scalar('Gen_LR',
                                   self.g_optimizer.param_groups[0]['lr'],
                                   epoch)
            self.writer.add_scalar('Dis_LR',
                                   self.d_optimizer.param_groups[0]['lr'],
                                   epoch)
            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
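            # One scheduler step per epoch; the scheduler objects themselves
            # are created elsewhere in the class and are not shown here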
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

        self.writer.close()
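
Both loops also rely on two small helpers that are not shown. Sketches
consistent with how they are called above (set_grad toggles requires_grad on
whole networks; utils.cuda moves a tensor, or a list of tensors, to the GPU
when one is available). The exact originals may differ:

import torch


def set_grad(nets, requires_grad=False):
    """Enable or disable gradient tracking for every parameter of each net."""
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad


def cuda(xs):
    """Move a tensor (or a list of tensors) to the GPU if one is available."""
    if not torch.cuda.is_available():
        return xs
    if isinstance(xs, (list, tuple)):
        return [x.cuda() for x in xs]
    return xs.cuda()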