コード例 #1
0
    def test(self, testloader, cur_epoch=-1):
        """Evaluate the model on ``testloader`` and report loss / accuracy.

        Args:
            testloader: iterable yielding ``(data, target)`` batches.
            cur_epoch: epoch index used as the step when writing summaries
                to ``self.summary_writer``; the default ``-1`` skips summary
                writing and only prints the results.
        """
        # Local import: the original only referenced ``Variable``, so ``torch``
        # itself is not guaranteed to be bound at module level here.
        import torch

        loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

        # Set the model to be in testing mode (for dropout and batchnorm)
        self.model.eval()

        # ``Variable(..., volatile=True)`` is a no-op since PyTorch 0.4, so the
        # autograd graph was still being recorded; ``no_grad`` disables it.
        with torch.no_grad():
            for data, target in testloader:
                if self.args.cuda:
                    data, target = data.cuda(), target.cuda()

                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.data,
                                                           target,
                                                           topk=(1, 5))
                # ``.data[0]`` raises on 0-dim tensors (PyTorch >= 0.5);
                # ``.item()`` / ``float()`` give plain Python numbers so the
                # trackers and the printed summary hold numbers, not tensors.
                loss.update(cur_loss.item())
                top1.update(float(cur_acc1[0]))
                top5.update(float(cur_acc5[0]))

        if cur_epoch != -1:
            # Summary Writing
            self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-1-acc", top1.avg,
                                           cur_epoch)
            self.summary_writer.add_scalar("test-top-5-acc", top5.avg,
                                           cur_epoch)

        print("Test Results" + " | " + "loss: " + str(loss.avg) +
              " - acc-top1: " + str(top1.avg)[:7] + "- acc-top5: " +
              str(top5.avg)[:7])
コード例 #2
0
    def train(self):
        """Train from ``self.start_epoch`` up to ``self.args.num_epochs - 1``.

        For every epoch this adjusts the learning rate, runs one optimisation
        pass over ``self.trainloader`` (with a tqdm progress bar), writes the
        epoch's loss / top-1 / top-5 averages to ``self.summary_writer``,
        evaluates on ``self.valloader`` every ``args.test_every`` epochs (when
        a validation loader is set) and saves a checkpoint.
        """
        for cur_epoch in range(self.start_epoch, self.args.num_epochs):

            # Initialize tqdm
            tqdm_batch = tqdm(self.trainloader,
                              desc="Epoch-" + str(cur_epoch) + "-")

            # Learning rate adjustment
            self.adjust_learning_rate(self.optimizer, cur_epoch)

            # Meters for tracking the average values
            loss, top1, top5 = AverageTracker(), AverageTracker(
            ), AverageTracker()

            # Set the model to be in training mode (for dropout and batchnorm)
            self.model.train()

            for data, target in tqdm_batch:

                if self.args.cuda:
                    data, target = data.cuda(), target.cuda()

                # Forward pass (tensors are autograd-aware since PyTorch 0.4,
                # so no ``Variable`` wrapper is needed)
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Optimization step
                self.optimizer.zero_grad()
                cur_loss.backward()
                self.optimizer.step()

                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.data,
                                                           target,
                                                           topk=(1, 5))
                # ``.data[0]`` raises on 0-dim tensors (PyTorch >= 0.5);
                # store plain Python numbers in the trackers instead.
                loss.update(cur_loss.item())
                top1.update(float(cur_acc1[0]))
                top5.update(float(cur_acc5[0]))

            # Summary Writing
            self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg,
                                           cur_epoch)
            self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg,
                                           cur_epoch)

            # Print in console
            tqdm_batch.close()
            print("Epoch-" + str(cur_epoch) + " | " + "loss: " +
                  str(loss.avg) + " - acc-top1: " + str(top1.avg)[:7] +
                  "- acc-top5: " + str(top5.avg)[:7])

            # Evaluate on Validation Set
            if cur_epoch % self.args.test_every == 0 and self.valloader:
                self.test(self.valloader, cur_epoch)

            # Checkpointing
            # NOTE(review): ``is_best`` is decided from this epoch's *training*
            # top-1 accuracy, not from the validation run above — confirm this
            # is the intended model-selection criterion.
            is_best = top1.avg > self.best_top1
            self.best_top1 = max(top1.avg, self.best_top1)
            self.save_checkpoint(
                {
                    'epoch': cur_epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'best_top1': self.best_top1,
                    'optimizer': self.optimizer.state_dict(),
                }, is_best)
コード例 #3
0
ファイル: train.py プロジェクト: aloyschen/Retinanet
def train():
    """
    Introduction
    ------------
        Train the RetinaNet model.

        Builds the VOC training set from ``config.voc_root``, then for each of
        ``config.Epochs`` epochs runs every batch through the network, encodes
        per-image localisation / classification targets against the fixed
        anchor boxes, and optimises the ``MultiBoxLoss`` with SGD. Progress is
        printed every ``config.print_freq`` batches and the weights are saved
        every ``config.save_freq`` epochs.
    """
    # Only consumed by the (disabled) COCO dataset alternative below.
    train_transform = Augmentation(size=config.image_size)
    # train_dataset = COCODataset(config.coco_train_dir, config.coco_train_annaFile, config.coco_label_file, training = True, transform = train_transform)
    from VOCDataset import build_vocDataset
    train_dataset = build_vocDataset(config.voc_root)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.train_batch,
                                  shuffle=True,
                                  num_workers=2,
                                  collate_fn=train_dataset.collate_fn)
    print("training on {} samples".format(len(train_dataset)))
    net = RetinaNet(config.num_classes, pre_train_path=config.resnet50_path)
    net.cuda()
    optimizer = optim.SGD(net.parameters(),
                          lr=config.learning_rate,
                          momentum=0.9,
                          weight_decay=1e-4)
    criterion = MultiBoxLoss(alpha=config.focal_alpha,
                             gamma=config.focal_gamma,
                             num_classes=config.num_classes)
    anchors = Anchor(config.anchor_areas, config.aspect_ratio,
                     config.scale_ratios)
    # The anchors depend only on the (fixed) input size, so compute them once.
    anchor_boxes = anchors(input_size=config.image_size)
    for epoch in range(config.Epochs):
        batch_time, loc_losses, conf_losses = AverageTracker(), AverageTracker(
        ), AverageTracker()
        net.train()
        net.freeze_bn()
        end = time.time()
        for index, (image, gt_boxes, labels) in enumerate(train_dataloader):
            image = image.cuda()
            loc_preds, cls_preds = net(image)
            batch_num = image.shape[0]
            # Encode the targets image by image. ``gt_boxes`` / ``labels`` are
            # indexed by the position *within the batch* (``idx``); indexing
            # them with the batch counter ``index`` gave every image the wrong
            # ground truth (or an IndexError once ``index >= batch_num``).
            loc_targets, cls_targets = [], []
            for idx in range(batch_num):
                loc_target, cls_target = encode(anchor_boxes, gt_boxes[idx],
                                                labels[idx])
                loc_targets.append(loc_target)
                cls_targets.append(cls_target)
            loc_targets = torch.stack(loc_targets).cuda()
            cls_targets = torch.stack(cls_targets).cuda()
            loc_loss, cls_loss = criterion(loc_preds, loc_targets, cls_preds,
                                           cls_targets)
            loss = loc_loss + cls_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loc_losses.update(loc_loss.item(), image.size(0))
            conf_losses.update(cls_loss.item(), image.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
            # Throttle logging by the batch counter (``idx`` would always be
            # ``batch_num - 1`` here, so the old check never varied per batch).
            if index % config.print_freq == 0:
                print(
                    'Epoch: {}/{} Batch: {}/{} loc Loss: {:.4f} {:.4f} conf loss: {:.4f} {:.4f} Time: {:.4f} {:.4f}'
                    .format(epoch, config.Epochs, index, len(train_dataloader),
                            loc_losses.val, loc_losses.avg, conf_losses.val,
                            conf_losses.avg, batch_time.val, batch_time.avg))
        if epoch % config.save_freq == 0:
            print('save model')
            torch.save(
                net.state_dict(),
                config.model_dir + 'train_model_epoch{}.pth'.format(epoch + 1))
コード例 #4
0
    def train(self):
        """Run adversarial training from ``start_epoch + 1`` to ``n_epoch - 1``.

        Each epoch:
          * for every batch of ``self.train_dataloader``: updates the
            discriminator on (image, prediction) vs (image, one-hot new label)
            pairs, then updates both generators (``OldLabel_generator`` and
            ``Image_generator``) through ``self.criterion_G``; per-iteration
            losses are printed, and every 10 iterations the running averages
            are appended to the history and plotted with ``vis``;
          * evaluates on ``self.valid_dataloader`` under ``torch.no_grad()``,
            writes a colour-map preview of the first validation batch, logs
            overall accuracy / recall / mean IoU and plots them;
          * saves all model and optimizer states to
            ``<cfg.TRAIN.OUTDIR>/checkpoints/<epoch>epoch.pth``.
        """
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        # x-axis for the training-loss plot; kept one longer than the history
        # lists only transiently (it is extended right after each plot).
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):

                # presumably meta = (image, old_label, new_label, names) — TODO confirm
                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                # Detach so the discriminator step does not backprop into
                # the generators.
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                # Use the instance config (as on the generator call above)
                # rather than a bare module-level ``cfg``.
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)),
                    1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                    total_loss.item(), corr_loss.item(), recover_loss.item(), change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                # Every 10 iterations: record the running averages, reset the
                # trackers and refresh the visdom loss plot.
                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    # Class index per pixel (argmax over axis 1).
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    # Side-by-side preview of the first two predictions of the
                    # first batch; skipped when that batch holds fewer than two
                    # samples instead of raising IndexError.
                    if j == 0 and preds.shape[0] >= 2:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            # Index [1] selects the second class's per-class value.
            recall = score['Recall: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
コード例 #5
0
    def train(self):
        """Train ``Image_generator`` from ``start_epoch + 1`` to ``n_epoch - 1``.

        Each epoch:
          * optimises ``self.criterion`` over ``self.train_dataloader``
            (stepping ``self.scheduler`` every iteration), printing each
            iteration's loss and plotting the 10-iteration running average
            with ``vis``;
          * evaluates on ``self.valid_dataloader`` under ``torch.no_grad()``,
            writes a colour-map preview of the first validation batch, logs
            overall accuracy / recall / mean IoU and plots them;
          * saves the model and optimizer state to
            ``<cfg.TRAIN.OUTDIR>/checkpoints/<epoch>epoch.pth``.
        """
        all_train_iter_total_loss = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.Image_generator.train()
            for i, meta in enumerate(self.train_dataloader):

                # presumably meta = (image, label, names) — TODO confirm
                new_image, new_label = meta[0].cuda(), meta[1].cuda()
                infer_pred = self.Image_generator(new_image)

                # ---------------
                # Train Generator
                # ---------------
                self.optimizer.zero_grad()
                total_loss = self.criterion(infer_pred, new_label)

                total_loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                iter_total_loss.update(total_loss.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, Generator Total Loss: {:.6f}'.format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg, total_loss.item())
                print(log)

                # Every 10 iterations: record the running average, reset the
                # tracker and refresh the visdom loss plot.
                if (i+1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    iter_total_loss.reset()

                    vis.line(
                        X=iter_num,
                        Y=all_train_iter_total_loss,
                        opts={'legend': ['total_loss'],
                              'linecolor': np.array([[255, 0, 0]]),
                              'title': 'Train loss of generator'},
                        win='Train loss of generator'
                    )
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.Image_generator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    new_image, new_label = meta[0].cuda(), meta[1].cuda()
                    infer_pred = self.Image_generator(new_image)
                    # Class index per pixel (argmax over axis 1).
                    preds = np.argmax(infer_pred.cpu().detach().numpy().copy(), axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)
                    # Side-by-side preview of the first two predictions of the
                    # first batch; skipped when that batch holds fewer than two
                    # samples instead of raising IndexError.
                    if j == 0 and preds.shape[0] >= 2:
                        color_map1 = gen_color_map(preds[0, :]).astype(np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(os.path.join(self.val_outdir, '{}epoch*{}*{}.png'
                                                 .format(epoch_i, meta[2][0], meta[2][1])), color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            # Index [1] selects the second class's per-class value.
            recall = score['Recall: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(
                X=np.column_stack(np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                Y=np.column_stack((
                    all_val_epo_acc,
                    all_val_epo_iou)),
                opts={
                    'legend': ['val epoch Overall Acc', 'val epoch Mean IoU'],
                    'linecolor': np.array(
                        [[255, 0, 0],
                         [0, 255, 0]]),
                    'title': 'Validate Accuracy and IoU'
                },
                win='validate Accuracy and IoU'
            )

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            state = {'epoch': epoch_i,
                     "acc": oa,
                     "recall": recall,
                     "iou": miou,
                     'model': self.Image_generator.state_dict(),
                     'optimizer': self.optimizer.state_dict(),}
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints', '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
コード例 #6
0
    def __init__(self, *args):
        """Set up logging, output directories, networks, optimizers and losses.

        All positional ``args`` are forwarded to the parent constructor, which
        is expected to provide ``self.args`` and ``self.device``.
        """
        super(DeepUNetTrainer, self).__init__(*args)

        # log file
        if self.args.train:
            ctime = time.ctime().split()

            log_path = './log'
            # makedirs(exist_ok=True) avoids the exists-then-create race.
            os.makedirs(log_path, exist_ok=True)

            # Directory named <year>_<month>_<day>_<hh:mm:ss>; creation still
            # fails loudly if a run with the same timestamp already exists.
            log_dir = os.path.join(
                log_path,
                '%s_%s_%s_%s' % (ctime[-1], ctime[1], ctime[2], ctime[3]))
            os.mkdir(log_dir)
            with open(os.path.join(log_dir, 'arg.txt'), 'w') as f:
                f.write(str(args))
            # Kept open for the trainer's lifetime (loss lines are appended
            # per epoch by ``train``).
            self.log_file = open(os.path.join(log_dir, 'loss.txt'), 'w')

        self.save_path = './data/result'
        # ``os.mkdir`` failed when the parent ``./data`` did not exist yet.
        os.makedirs(self.save_path, exist_ok=True)

        # build model
        self.generator = DeepUNetPaintGenerator().to(self.device)
        self.discriminator = PatchGAN(sigmoid=self.args.no_mse).to(self.device)

        # set optimizers
        self.optimizers = self._set_optimizers()

        # set loss functions
        self.losses = self._set_losses()

        # set image pooler
        self.image_pool = ImagePooling(50)

        # load pretrained model
        if self.args.pretrainedG != '':
            if self.args.verbose:
                print('load pretrained generator...')
            load_checkpoints(self.args.pretrainedG, self.generator,
                             self.optimizers['G'])
        if self.args.pretrainedD != '':
            if self.args.verbose:
                print('load pretrained discriminator...')
            load_checkpoints(self.args.pretrainedD, self.discriminator,
                             self.optimizers['D'])

        # Wrap only after checkpoints are loaded, so their state-dict keys
        # match the unwrapped modules.
        if self.device.type == 'cuda':
            # enable parallel computation
            self.generator = nn.DataParallel(self.generator)
            self.discriminator = nn.DataParallel(self.discriminator)

        # loss values for tracking
        self.loss_G_gan = AverageTracker('loss_G_gan')
        self.loss_G_l1 = AverageTracker('loss_G_l1')
        self.loss_D_real = AverageTracker('loss_D_real')
        self.loss_D_fake = AverageTracker('loss_D_fake')

        # image value
        self.imageA = None
        self.imageB = None
        self.fakeB = None
コード例 #7
0
class DeepUNetTrainer(ModelTrainer):
    """Trainer for the DeepUNet paint generator against a PatchGAN critic.

    The generator maps ``imageA`` plus a color hint tensor to ``fakeB`` and
    is trained with a GAN loss plus an L1 reconstruction loss weighted by
    ``args.lambd``. The discriminator scores channel-concatenated
    ``(imageA, imageB)`` pairs.
    """

    def __init__(self, *args):
        """Set up logging, output directories, networks, optimizers and losses.

        All positional ``args`` are forwarded to ``ModelTrainer``, which is
        expected to provide ``self.args`` and ``self.device``.
        """
        super(DeepUNetTrainer, self).__init__(*args)

        # log file
        if self.args.train:
            ctime = time.ctime().split()

            log_path = './log'
            # makedirs(exist_ok=True) avoids the exists-then-create race.
            os.makedirs(log_path, exist_ok=True)

            # Directory named <year>_<month>_<day>_<hh:mm:ss>; creation still
            # fails loudly if a run with the same timestamp already exists.
            log_dir = os.path.join(
                log_path,
                '%s_%s_%s_%s' % (ctime[-1], ctime[1], ctime[2], ctime[3]))
            os.mkdir(log_dir)
            with open(os.path.join(log_dir, 'arg.txt'), 'w') as f:
                f.write(str(args))
            # Kept open for the trainer's lifetime; ``train`` appends one
            # tab-separated loss line per epoch.
            self.log_file = open(os.path.join(log_dir, 'loss.txt'), 'w')

        self.save_path = './data/result'
        # ``os.mkdir`` failed when the parent ``./data`` did not exist yet.
        os.makedirs(self.save_path, exist_ok=True)

        # build model
        self.generator = DeepUNetPaintGenerator().to(self.device)
        self.discriminator = PatchGAN(sigmoid=self.args.no_mse).to(self.device)

        # set optimizers
        self.optimizers = self._set_optimizers()

        # set loss functions
        self.losses = self._set_losses()

        # set image pooler
        self.image_pool = ImagePooling(50)

        # load pretrained model
        if self.args.pretrainedG != '':
            if self.args.verbose:
                print('load pretrained generator...')
            load_checkpoints(self.args.pretrainedG, self.generator,
                             self.optimizers['G'])
        if self.args.pretrainedD != '':
            if self.args.verbose:
                print('load pretrained discriminator...')
            load_checkpoints(self.args.pretrainedD, self.discriminator,
                             self.optimizers['D'])

        # Wrap only after checkpoints are loaded, so their state-dict keys
        # match the unwrapped modules.
        if self.device.type == 'cuda':
            # enable parallel computation
            self.generator = nn.DataParallel(self.generator)
            self.discriminator = nn.DataParallel(self.discriminator)

        # loss values for tracking
        self.loss_G_gan = AverageTracker('loss_G_gan')
        self.loss_G_l1 = AverageTracker('loss_G_l1')
        self.loss_D_real = AverageTracker('loss_D_real')
        self.loss_D_fake = AverageTracker('loss_D_fake')

        # image value
        self.imageA = None
        self.imageB = None
        self.fakeB = None

    def train(self, last_iteration):
        """
        Run single epoch.

        Args:
            last_iteration: start value of the iteration counter (passed to
                ``enumerate`` as ``start``); the counter drives the
                ``args.print_every`` logging cadence.

        Returns:
            The counter value of the last batch processed, or
            ``last_iteration`` unchanged when the data loader is empty
            (previously that case raised ``UnboundLocalError``).
        """
        average_trackers = [
            self.loss_G_gan, self.loss_D_fake, self.loss_D_real, self.loss_G_l1
        ]
        self.generator.train()
        self.discriminator.train()
        for tracker in average_trackers:
            tracker.initialize()

        i = last_iteration
        for i, datas in enumerate(self.data_loader, last_iteration):
            imageA, imageB, colors = datas
            if self.args.mode == 'B2A':
                # swap
                imageA, imageB = imageB, imageA

            self.imageA = imageA.to(self.device)
            self.imageB = imageB.to(self.device)
            colors = colors.to(self.device)

            # run forward propagation. ignore attention
            self.fakeB, _ = self.generator(
                self.imageA,
                colors,
            )

            # Discriminator first (it detaches fakeB), then the generator,
            # which reuses the same fakeB graph.
            self._update_discriminator()
            self._update_generator()

            if self.args.verbose and i % self.args.print_every == 0:
                print('%s = %f, %s = %f, %s = %f, %s = %f' % (
                    self.loss_D_real.name,
                    self.loss_D_real(),
                    self.loss_D_fake.name,
                    self.loss_D_fake(),
                    self.loss_G_gan.name,
                    self.loss_G_gan(),
                    self.loss_G_l1.name,
                    self.loss_G_l1(),
                ))

        self.log_file.write('%f\t%f\t%f\t%f\n' %
                            (self.loss_D_real(), self.loss_D_fake(),
                             self.loss_G_gan(), self.loss_G_l1()))
        return i

    def validate(self, dataset, epoch, samples=3):
        """Save qualitative result sheets and print average validation losses.

        ``epoch`` sheets are produced. For each sheet, ``2 * samples`` distinct
        indices are drawn from ``dataset``: the first ``samples`` are targets
        and the rest supply the style image and color hints. Each sheet row is
        ``imageA | styleB | fakeB | imageB | color swatches``.

        Raises:
            ValueError: from ``random.sample`` if ``2 * samples > len(dataset)``.
        """
        # NOTE(review): the networks are not switched to eval() here, so any
        # dropout / batch-norm layers run in their current (training) mode —
        # confirm this is intended.
        length = len(dataset)

        # sample images (``range(length)`` so the last item can be drawn too;
        # ``range(0, length - 1)`` silently excluded it)
        idxs_total = [
            random.sample(range(length), samples * 2)
            for _ in range(epoch)
        ]

        toPIL = transforms.ToPILImage()
        l1_loss = self.losses['L1']
        gan_loss = self.losses['GAN']

        for j, idxs in enumerate(idxs_total):
            styles = idxs[samples:]
            targets = idxs[0:samples]

            result = Image.new(
                'RGB', (5 * self.resolution, samples * self.resolution))

            G_loss_gan = []
            G_loss_l1 = []
            D_loss_real = []
            D_loss_fake = []
            for i, (target, style) in enumerate(zip(targets, styles)):
                sub_result = Image.new('RGB',
                                       (5 * self.resolution, self.resolution))
                imageA, imageB, _ = dataset[target]
                styleA, styleB, colors = dataset[style]

                if self.args.mode == 'B2A':
                    imageA, imageB = imageB, imageA
                    styleA, styleB = styleB, styleA

                imageA = imageA.unsqueeze(0).to(self.device)
                imageB = imageB.unsqueeze(0).to(self.device)
                styleB = styleB.unsqueeze(0).to(self.device)
                colors = colors.unsqueeze(0).to(self.device)

                with torch.no_grad():
                    fakeB, _ = self.generator(
                        imageA,
                        colors,
                    )
                    fakeAB = torch.cat([imageA, fakeB], 1)
                    realAB = torch.cat([imageA, imageB], 1)

                    G_loss_l1.append(l1_loss(fakeB, imageB).item())
                    G_loss_gan.append(
                        gan_loss(self.discriminator(fakeAB), True).item())

                    D_loss_real.append(
                        gan_loss(self.discriminator(realAB), True).item())
                    D_loss_fake.append(
                        gan_loss(self.discriminator(fakeAB), False).item())

                styleB = styleB.squeeze()
                fakeB = fakeB.squeeze()
                imageA = imageA.squeeze()
                imageB = imageB.squeeze()
                colors = colors.squeeze()

                imageA = toPIL(re_scale(imageA).detach().cpu())
                imageB = toPIL(re_scale(imageB).detach().cpu())
                styleB = toPIL(re_scale(styleB).detach().cpu())
                fakeB = toPIL(re_scale(fakeB).detach().cpu())

                # synthesize top-4 colors (each 3-channel slice of ``colors``
                # becomes one horizontal band of the swatch tile)
                color1 = toPIL(re_scale(colors[0:3].detach().cpu()))
                color2 = toPIL(re_scale(colors[3:6].detach().cpu()))
                color3 = toPIL(re_scale(colors[6:9].detach().cpu()))
                color4 = toPIL(re_scale(colors[9:12].detach().cpu()))

                color1 = color1.rotate(90)
                color2 = color2.rotate(90)
                color3 = color3.rotate(90)
                color4 = color4.rotate(90)

                color_result = Image.new('RGB',
                                         (self.resolution, self.resolution))
                color_result.paste(
                    color1.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, 0))
                color_result.paste(
                    color2.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, self.resolution // 4))
                color_result.paste(
                    color3.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, self.resolution // 4 * 2))
                color_result.paste(
                    color4.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, self.resolution // 4 * 3))

                sub_result.paste(imageA, (0, 0))
                sub_result.paste(styleB, (self.resolution, 0))
                sub_result.paste(fakeB, (2 * self.resolution, 0))
                sub_result.paste(imageB, (3 * self.resolution, 0))
                sub_result.paste(color_result, (4 * self.resolution, 0))

                result.paste(sub_result, (0, 0 + self.resolution * i))

            print(
                'Validate D_loss_real = %f, D_loss_fake = %f, G_loss_l1 = %f, G_loss_gan = %f'
                % (
                    sum(D_loss_real) / samples,
                    sum(D_loss_fake) / samples,
                    sum(G_loss_l1) / samples,
                    sum(G_loss_gan) / samples,
                ))

            save_image(
                result,
                'deepunetpaint_%03d_%02d' % (epoch, j),
                self.save_path,
            )

    def test(self):
        """Not implemented for this trainer."""
        raise NotImplementedError

    def save_model(self, name, epoch):
        """Save generator and discriminator checkpoints as ``name+'G'`` / ``name+'D'``."""
        save_checkpoints(
            self.generator,
            name + 'G',
            epoch,
            optimizer=self.optimizers['G'],
        )
        save_checkpoints(self.discriminator,
                         name + 'D',
                         epoch,
                         optimizer=self.optimizers['D'])

    def _set_optimizers(self):
        """Return Adam optimizers ``{'G': ..., 'D': ...}`` sharing lr and beta1."""
        optimG = optim.Adam(self.generator.parameters(),
                            lr=self.args.learning_rate,
                            betas=(self.args.beta1, 0.999))
        optimD = optim.Adam(self.discriminator.parameters(),
                            lr=self.args.learning_rate,
                            betas=(self.args.beta1, 0.999))

        return {'G': optimG, 'D': optimD}

    def _set_losses(self):
        """Return ``{'GAN': GANLoss, 'L1': nn.L1Loss}`` on ``self.device``."""
        # ``args.no_mse`` toggles the GANLoss mode (it is also passed as the
        # PatchGAN ``sigmoid`` flag in ``__init__``).
        gan_loss = GANLoss(not self.args.no_mse).to(self.device)
        l1_loss = nn.L1Loss().to(self.device)

        return {'GAN': gan_loss, 'L1': l1_loss}

    def _update_generator(self):
        """One generator step: GAN loss on (imageA, fakeB) plus lambd * L1."""
        optimG = self.optimizers['G']
        gan_loss = self.losses['GAN']
        l1_loss = self.losses['L1']
        batch_size = self.imageA.shape[0]

        optimG.zero_grad()
        fake_AB = torch.cat([self.imageA, self.fakeB], 1)
        logit_fake = self.discriminator(fake_AB)
        loss_G_gan = gan_loss(logit_fake, True)

        loss_G_l1 = l1_loss(self.fakeB, self.imageB) * self.args.lambd

        self.loss_G_gan.update(loss_G_gan.item(), batch_size)
        self.loss_G_l1.update(loss_G_l1.item(), batch_size)

        loss_G = loss_G_gan + loss_G_l1

        loss_G.backward()
        optimG.step()

    def _update_discriminator(self):
        """One discriminator step on real pairs and pooled, detached fake pairs."""
        optimD = self.optimizers['D']
        gan_loss = self.losses['GAN']
        batch_size = self.imageA.shape[0]

        optimD.zero_grad()

        # for real image
        real_AB = torch.cat([self.imageA, self.imageB], 1)
        logit_real = self.discriminator(real_AB)
        loss_D_real = gan_loss(logit_real, True)
        self.loss_D_real.update(loss_D_real.item(), batch_size)

        # for fake image (detached so the generator receives no gradient here)
        fake_AB = torch.cat([self.imageA, self.fakeB], 1)
        fake_AB = self.image_pool(fake_AB)
        logit_fake = self.discriminator(fake_AB.detach())
        loss_D_fake = gan_loss(logit_fake, False)
        self.loss_D_fake.update(loss_D_fake.item(), batch_size)

        # Halve the summed loss so D learns at half the rate of G.
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimD.step()