Example #1
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        max_len = max(len(a_loader), len(b_loader))

        steps = 0
        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            a_it = iter(a_loader)
            b_it = iter(b_loader)

            for i in range(max_len):
                try:
                    a_real = next(a_it)[0]
                except StopIteration:
                    # The shorter loader is exhausted: restart it and fetch a fresh batch.
                    a_it = iter(a_loader)
                    a_real = next(a_it)[0]

                try:
                    b_real = next(b_it)[0]
                except StopIteration:
                    b_it = iter(b_loader)
                    b_real = next(b_it)[0]

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real)
                b_real = Variable(b_real)
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Identity losses
                ###################################################
                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # lamda = 1.75E+12
                lamda = args.lamda * args.idt_coef
                a_idt_loss = self.L1(a_idt, a_real) * lamda
                b_idt_loss = self.L1(b_idt, b_real) * lamda

                # a_real_features = vgg.get_features(a_real)
                # b_real_features = vgg.get_features(b_real)
                # a_fake_features = vgg.get_features(a_fake)
                # b_fake_features = vgg.get_features(b_fake)

                # Content losses
                # content_loss_weight = 1.50
                # content_loss_weight = 1
                # a_content_loss = vgg.get_content_loss(b_fake_features, a_real_features) * content_loss_weight
                # b_content_loss = vgg.get_content_loss(a_fake_features, b_real_features) * content_loss_weight

                # Style losses
                # style_loss_weight = 3.00E+05
                # style_loss_weight = 1

                # a_style_loss = vgg.get_style_loss(a_fake_features, a_real_features) * style_loss_weight
                # b_style_loss = vgg.get_style_loss(b_fake_features, b_real_features) * style_loss_weight

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                # gen_loss_weight = 4.50E+08
                gen_loss_weight = 1
                a_gen_loss = self.MSE(a_fake_dis, real_label) * gen_loss_weight
                b_gen_loss = self.MSE(b_fake_dis, real_label) * gen_loss_weight

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda
                # lamda = 3.50E+12
                # a_cycle_loss = self.L1(a_recon, a_real) * lamda
                # b_cycle_loss = self.L1(b_recon, b_real) * lamda

                # gen_loss = a_gen_loss + b_gen_loss +\
                #            a_cycle_loss + b_cycle_loss +\
                #            a_style_loss + b_style_loss +\
                #            a_content_loss + b_content_loss +\
                #            a_idt_loss + b_idt_loss

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                steps += 1
                # print_msg (the printing interval, in steps) is assumed to be
                # defined elsewhere, e.g. as a module-level constant or an arg.
                if steps % print_msg == 0:
                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, max_len, gen_loss,
                           a_dis_loss + b_dis_loss))

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
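
set_grad is used above but not defined in these snippets; a minimal sketch of a helper consistent with how it is called (freezing the discriminators during the generator update, then unfreezing them for their own update) could look like this:

def set_grad(nets, requires_grad=False):
    # Toggle gradient computation for every parameter of the given networks,
    # e.g. set_grad([Da, Db], False) freezes both discriminators.
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad
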
Example #2
    def train(self, args):
        transform = get_transformation((args.crop_height, args.crop_width),
                                       resize=True,
                                       dataset=args.dataset)

        # Make the dataset choice configurable; root, root_cityscapes and
        # root_acdc are assumed to be dataset root paths defined elsewhere.
        if self.args.dataset == 'voc2012':
            labeled_set = VOCDataset(root_path=root,
                                     name='label',
                                     ratio=0.2,
                                     transformation=transform,
                                     augmentation=None)
            unlabeled_set = VOCDataset(root_path=root,
                                       name='unlabel',
                                       ratio=0.2,
                                       transformation=transform,
                                       augmentation=None)
            val_set = VOCDataset(root_path=root,
                                 name='val',
                                 ratio=0.5,
                                 transformation=transform,
                                 augmentation=None)
        elif self.args.dataset == 'cityscapes':
            labeled_set = CityscapesDataset(root_path=root_cityscapes,
                                            name='label',
                                            ratio=0.5,
                                            transformation=transform,
                                            augmentation=None)
            unlabeled_set = CityscapesDataset(root_path=root_cityscapes,
                                              name='unlabel',
                                              ratio=0.5,
                                              transformation=transform,
                                              augmentation=None)
            val_set = CityscapesDataset(root_path=root_cityscapes,
                                        name='val',
                                        ratio=0.5,
                                        transformation=transform,
                                        augmentation=None)
        elif self.args.dataset == 'acdc':
            labeled_set = ACDCDataset(root_path=root_acdc,
                                      name='label',
                                      ratio=0.5,
                                      transformation=transform,
                                      augmentation=None)
            unlabeled_set = ACDCDataset(root_path=root_acdc,
                                        name='unlabel',
                                        ratio=0.5,
                                        transformation=transform,
                                        augmentation=None)
            val_set = ACDCDataset(root_path=root_acdc,
                                  name='val',
                                  ratio=0.5,
                                  transformation=transform,
                                  augmentation=None)
        '''
        https://discuss.pytorch.org/t/about-the-relation-between-batch-size-and-length-of-data-loader/10510
        ^^ drop_last=True is used so that every batch has the same size:
        the final, smaller batch is simply dropped.
        '''
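        # For example (standard DataLoader behaviour): with drop_last=True,
        # len(loader) == len(dataset) // batch_size, whereas drop_last=False
        # would give ceil(len(dataset) / batch_size) with a smaller final batch.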
        labeled_loader = DataLoader(labeled_set,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    drop_last=True)
        unlabeled_loader = DataLoader(unlabeled_set,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=True)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=True,
                                drop_last=True)

        new_img_fake_sample = utils.Sample_from_Pool()
        img_fake_sample = utils.Sample_from_Pool()
        gt_fake_sample = utils.Sample_from_Pool()

        img_dis_loss, gt_dis_loss, unsupervisedloss, fullsupervisedloss = 0, 0, 0, 0

        ### Counter regulating how often the discriminators are updated relative to the generators
        counter = 0

        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            self.Gsi.train()
            self.Gis.train()

            # if (epoch+1)%10 == 0:
            # args.lamda_img = args.lamda_img + 0.08
            # args.lamda_gt = args.lamda_gt + 0.04

            for i, ((l_img, l_gt, _),
                    (unl_img, _,
                     _)) in enumerate(zip(labeled_loader, unlabeled_loader)):
                # step
                step = epoch * min(len(labeled_loader),
                                   len(unlabeled_loader)) + i + 1

                l_img, unl_img, l_gt = utils.cuda([l_img, unl_img, l_gt],
                                                  args.gpu_ids)

                # Generator Computations
                ##################################################

                set_grad([self.Di, self.Ds, self.old_Di], False)
                set_grad([self.old_Gsi, self.old_Gis], False)
                self.g_optimizer.zero_grad()

                # Forward pass through generators
                ##################################################
                fake_img = self.Gis(
                    make_one_hot(l_gt, args.dataset, args.gpu_ids).float())
                fake_gt = self.Gsi(unl_img.float())  ### having 21 channels
                lab_gt = self.Gsi(l_img)  ### having 21 channels

                ### Getting the outputs of the model to correct dimensions
                fake_img = self.interp(fake_img)
                fake_gt = self.interp(fake_gt)
                lab_gt = self.interp(lab_gt)

                # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)  ### will get into no channels
                # fake_gt = fake_gt.unsqueeze(1)   ### will get into 1 channel only
                # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)

                lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))

                ### Again applying activations
                lab_gt = self.activation_softmax(lab_gt)
                fake_gt = self.activation_softmax(fake_gt)
                # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
                # fake_gt = fake_gt.unsqueeze(1)
                # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)
                # fake_img = self.activation_tanh(fake_img)

                recon_img = self.Gis(fake_gt.float())
                recon_lab_img = self.Gis(lab_gt.float())
                recon_gt = self.Gsi(fake_img.float())

                ### Getting the outputs of the model to correct dimensions
                recon_img = self.interp(recon_img)
                recon_lab_img = self.interp(recon_lab_img)
                recon_gt = self.interp(recon_gt)

                ### New loss term comparing recon_img from the old (ResNet) networks and the current (DeepLab) networks
                resnet_fake_gt = self.old_Gsi(unl_img.float())
                resnet_lab_gt = self.old_Gsi(l_img)
                resnet_lab_gt = self.activation_softmax(resnet_lab_gt)
                resnet_fake_gt = self.activation_softmax(resnet_fake_gt)
                resnet_recon_img = self.old_Gis(resnet_fake_gt.float())
                resnet_recon_lab_img = self.old_Gis(resnet_lab_gt.float())

                ## Applying the tanh activations
                # recon_img = self.activation_tanh(recon_img)
                # recon_lab_img = self.activation_tanh(recon_lab_img)

                # Adversarial losses
                ###################################################
                fake_img_dis = self.Di(fake_img)
                resnet_fake_img_dis = self.old_Di(recon_img)

                ### For passing different type of input to Ds
                fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(
                    1).squeeze_(0)
                fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
                fake_gt_discriminator = make_one_hot(fake_gt_discriminator,
                                                     args.dataset,
                                                     args.gpu_ids)
                fake_gt_dis = self.Ds(fake_gt_discriminator.float())
                # lab_gt_dis = self.Ds(lab_gt)

                real_label_gt = utils.cuda(
                    Variable(torch.ones(fake_gt_dis.size())), args.gpu_ids)
                real_label_img = utils.cuda(
                    Variable(torch.ones(fake_img_dis.size())), args.gpu_ids)

                # A cross-entropy loss would arguably be better suited here, since this is a classification output.
                img_gen_loss = self.MSE(fake_img_dis, real_label_img)
                gt_gen_loss = self.MSE(fake_gt_dis, real_label_gt)
                # gt_label_gen_loss = self.MSE(lab_gt_dis, real_label)

                # Cycle consistency losses
                ###################################################
                resnet_img_cycle_loss = self.MSE(resnet_fake_img_dis,
                                                 real_label_img)
                # img_cycle_loss = self.L1(recon_img, unl_img)
                # img_cycle_loss_perceptual = perceptual_loss(recon_img, unl_img, args.gpu_ids)
                gt_cycle_loss = self.CE(recon_gt, l_gt.squeeze(1))
                # lab_img_cycle_loss = self.L1(recon_lab_img, l_img) * args.lamda

                # Total generators losses
                ###################################################
                # lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))
                lab_loss_MSE = self.L1(fake_img, l_img)
                # lab_loss_perceptual = perceptual_loss(fake_img, l_img, args.gpu_ids)

                fullsupervisedloss = args.lab_CE_weight * lab_loss_CE + args.lab_MSE_weight * lab_loss_MSE

                unsupervisedloss = args.adversarial_weight * (
                    img_gen_loss + gt_gen_loss
                ) + resnet_img_cycle_loss + gt_cycle_loss * args.lamda_gt

                gen_loss = fullsupervisedloss + unsupervisedloss

                # Update generators
                ###################################################
                gen_loss.backward()

                self.g_optimizer.step()

                ### counter % 1 is always 0, so the discriminators are updated on
                ### every step; increase the modulus to update them less often.
                if counter % 1 == 0:
                    # Discriminator Computations
                    #################################################

                    set_grad([self.Di, self.Ds, self.old_Di], True)
                    self.d_optimizer.zero_grad()

                    # Sample from history of generated images
                    #################################################
                    # With a threshold of 0.0 this branch never fires, i.e. the
                    # Gaussian-noise injection is currently disabled.
                    if torch.rand(1) < 0.0:
                        fake_img = self.gauss_noise(fake_img.cpu())
                        fake_gt = self.gauss_noise(fake_gt.cpu())

                    recon_img = Variable(
                        torch.Tensor(
                            new_img_fake_sample([recon_img.cpu().data.numpy()
                                                 ])[0]))
                    fake_img = Variable(
                        torch.Tensor(
                            img_fake_sample([fake_img.cpu().data.numpy()])[0]))
                    # lab_gt = Variable(torch.Tensor(gt_fake_sample([lab_gt.cpu().data.numpy()])[0]))
                    fake_gt = Variable(
                        torch.Tensor(
                            gt_fake_sample([fake_gt.cpu().data.numpy()])[0]))

                    recon_img, fake_img, fake_gt = utils.cuda(
                        [recon_img, fake_img, fake_gt], args.gpu_ids)

                    # Forward pass through discriminators
                    #################################################
                    unl_img_dis = self.Di(unl_img)
                    fake_img_dis = self.Di(fake_img)
                    resnet_recon_img_dis = self.old_Di(resnet_recon_img)
                    resnet_fake_img_dis = self.old_Di(recon_img)

                    # lab_gt_dis = self.Ds(lab_gt)

                    l_gt = make_one_hot(l_gt, args.dataset, args.gpu_ids)
                    real_gt_dis = self.Ds(l_gt.float())

                    fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(
                        1).squeeze_(0)
                    fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
                    fake_gt_discriminator = make_one_hot(
                        fake_gt_discriminator, args.dataset, args.gpu_ids)
                    fake_gt_dis = self.Ds(fake_gt_discriminator.float())

                    real_label_img = utils.cuda(
                        Variable(torch.ones(unl_img_dis.size())), args.gpu_ids)
                    fake_label_img = utils.cuda(
                        Variable(torch.zeros(fake_img_dis.size())),
                        args.gpu_ids)
                    real_label_gt = utils.cuda(
                        Variable(torch.ones(real_gt_dis.size())), args.gpu_ids)
                    fake_label_gt = utils.cuda(
                        Variable(torch.zeros(fake_gt_dis.size())),
                        args.gpu_ids)

                    # Discriminator losses
                    ##################################################
                    img_dis_real_loss = self.MSE(unl_img_dis, real_label_img)
                    img_dis_fake_loss = self.MSE(fake_img_dis, fake_label_img)
                    gt_dis_real_loss = self.MSE(real_gt_dis, real_label_gt)
                    gt_dis_fake_loss = self.MSE(fake_gt_dis, fake_label_gt)
                    # lab_gt_dis_fake_loss = self.MSE(lab_gt_dis, fake_label)

                    cycle_img_dis_real_loss = self.MSE(resnet_recon_img_dis,
                                                       real_label_img)
                    cycle_img_dis_fake_loss = self.MSE(resnet_fake_img_dis,
                                                       fake_label_img)

                    # Total discriminators losses
                    img_dis_loss = (img_dis_real_loss +
                                    img_dis_fake_loss) * 0.5
                    gt_dis_loss = (gt_dis_real_loss + gt_dis_fake_loss) * 0.5
                    # lab_gt_dis_loss = (gt_dis_real_loss + lab_gt_dis_fake_loss)*0.33
                    cycle_img_dis_loss = cycle_img_dis_real_loss + cycle_img_dis_fake_loss

                    # Update discriminators
                    ##################################################
                    discriminator_loss = args.discriminator_weight * (
                        img_dis_loss + gt_dis_loss) + cycle_img_dis_loss
                    discriminator_loss.backward()

                    # lab_gt_dis_loss.backward()
                    self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Dis Loss:%.2e | Unlab Gen Loss:%.2e | Lab Gen loss:%.2e"
                    % (epoch, i + 1,
                       min(len(labeled_loader),
                           len(unlabeled_loader)), img_dis_loss + gt_dis_loss,
                       unsupervisedloss, fullsupervisedloss))

                self.writer_semisuper.add_scalars(
                    'Dis Loss', {
                        'img_dis_loss': img_dis_loss,
                        'gt_dis_loss': gt_dis_loss,
                        'cycle_img_dis_loss': cycle_img_dis_loss
                    },
                    len(labeled_loader) * epoch + i)
                self.writer_semisuper.add_scalars(
                    'Unlabelled Loss', {
                        'img_gen_loss': img_gen_loss,
                        'gt_gen_loss': gt_gen_loss,
                        'img_cycle_loss': resnet_img_cycle_loss,
                        'gt_cycle_loss': gt_cycle_loss
                    },
                    len(labeled_loader) * epoch + i)
                self.writer_semisuper.add_scalars(
                    'Labelled Loss', {
                        'lab_loss_CE': lab_loss_CE,
                        'lab_loss_MSE': lab_loss_MSE
                    },
                    len(labeled_loader) * epoch + i)

                counter += 1

            ### For getting the mean IoU
            self.Gsi.eval()
            self.Gis.eval()
            with torch.no_grad():
                for i, (val_img, val_gt, _) in enumerate(val_loader):
                    val_img, val_gt = utils.cuda([val_img, val_gt],
                                                 args.gpu_ids)

                    outputs = self.Gsi(val_img)
                    outputs = self.interp(outputs)
                    outputs = self.activation_softmax(outputs)

                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = val_gt.squeeze().data.cpu().numpy()

                    self.running_metrics_val.update(gt, pred)

            score, class_iou = self.running_metrics_val.get_scores()

            self.running_metrics_val.reset()

            print('The mIoU for the epoch is: ', score["Mean IoU : \t"])

            ### Display the generator outputs on TensorBoard using validation images
            val_image, val_gt, _ = next(iter(val_loader))
            val_image, val_gt = utils.cuda([val_image, val_gt], args.gpu_ids)
            with torch.no_grad():
                fake_label = self.Gsi(val_image).detach()
                fake_label = self.interp(fake_label)
                fake_label = self.activation_softmax(fake_label)
                fake_label = fake_label.data.max(1)[1].squeeze_(1).squeeze_(0)
                fake_label = fake_label.unsqueeze(1)
                fake_label = make_one_hot(fake_label, args.dataset,
                                          args.gpu_ids)
                fake_img = self.Gis(fake_label).detach()
                fake_img = self.interp(fake_img)
                # fake_img = self.activation_tanh(fake_img)

                fake_img_from_labels = self.Gis(
                    make_one_hot(val_gt, args.dataset,
                                 args.gpu_ids).float()).detach()
                fake_img_from_labels = self.interp(fake_img_from_labels)
                # fake_img_from_labels = self.activation_tanh(fake_img_from_labels)
                fake_label_regenerated = self.Gsi(
                    fake_img_from_labels).detach()
                fake_label_regenerated = self.interp(fake_label_regenerated)
                fake_label_regenerated = self.activation_softmax(
                    fake_label_regenerated)
            fake_prediction_label = fake_label.data.max(1)[1].squeeze_(
                1).cpu().numpy()
            fake_regenerated_label = fake_label_regenerated.data.max(
                1)[1].squeeze_(1).cpu().numpy()
            val_gt = val_gt.cpu()

            fake_img = fake_img.cpu()
            fake_img_from_labels = fake_img_from_labels.cpu()
            ### Undo the normalization on these images before displaying them
            if self.args.dataset == 'voc2012' or self.args.dataset == 'cityscapes':
                trans_mean = [0.5, 0.5, 0.5]
                trans_std = [0.5, 0.5, 0.5]
                for i in range(3):
                    fake_img[:, i, :, :] = (
                        (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
                    fake_img_from_labels[:, i, :, :] = (
                        (fake_img_from_labels[:, i, :, :] * trans_std[i]) +
                        trans_mean[i])

            elif self.args.dataset == 'acdc':
                trans_mean = [0.5]
                trans_std = [0.5]
                for i in range(1):
                    fake_img[:, i, :, :] = (
                        (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
                    fake_img_from_labels[:, i, :, :] = (
                        (fake_img_from_labels[:, i, :, :] * trans_std[i]) +
                        trans_mean[i])

            ### display_tensor is the final tensor that will be displayed on tensorboard
            display_tensor_label = torch.zeros([
                fake_label.shape[0], 3, fake_label.shape[2],
                fake_label.shape[3]
            ])
            display_tensor_gt = torch.zeros(
                [val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]])
            display_tensor_regen_label = torch.zeros([
                fake_label_regenerated.shape[0], 3,
                fake_label_regenerated.shape[2],
                fake_label_regenerated.shape[3]
            ])
            for i in range(fake_prediction_label.shape[0]):
                new_img_label = fake_prediction_label[i]
                new_img_label = utils.colorize_mask(
                    new_img_label, self.args.dataset
                )  ### So this is the generated image in PIL.Image format
                img_tensor_label = utils.PIL_to_tensor(new_img_label,
                                                       self.args.dataset)
                display_tensor_label[i, :, :, :] = img_tensor_label

                display_tensor_gt[i, :, :, :] = val_gt[i]

                regen_label = fake_regenerated_label[i]
                regen_label = utils.colorize_mask(regen_label,
                                                  self.args.dataset)
                regen_tensor_label = utils.PIL_to_tensor(
                    regen_label, self.args.dataset)
                display_tensor_regen_label[i, :, :, :] = regen_tensor_label

            self.writer_semisuper.add_image(
                'Generated segmented image: ',
                torchvision.utils.make_grid(display_tensor_label,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Generated image back from segmentation: ',
                torchvision.utils.make_grid(fake_img, nrow=2, normalize=True),
                epoch)
            self.writer_semisuper.add_image(
                'Ground truth for the image: ',
                torchvision.utils.make_grid(display_tensor_gt,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Image generated from val labels: ',
                torchvision.utils.make_grid(fake_img_from_labels,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Labels generated back from the cycle: ',
                torchvision.utils.make_grid(display_tensor_regen_label,
                                            nrow=2,
                                            normalize=True), epoch)

            if score["Mean IoU : \t"] >= self.best_iou:
                self.best_iou = score["Mean IoU : \t"]

                # Override the latest checkpoint
                #######################################################
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'Di': self.Di.state_dict(),
                        'Ds': self.Ds.state_dict(),
                        'Gis': self.Gis.state_dict(),
                        'Gsi': self.Gsi.state_dict(),
                        'd_optimizer': self.d_optimizer.state_dict(),
                        'g_optimizer': self.g_optimizer.state_dict(),
                        'best_iou': self.best_iou,
                        'class_iou': class_iou
                    }, '%s/latest_semisuper_cycleGAN.ckpt' %
                    (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

        self.writer_semisuper.close()
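
make_one_hot is imported from elsewhere in this project; a rough sketch of what such a helper typically does (scattering integer labels of shape (N, 1, H, W) into a one-hot tensor of shape (N, C, H, W)) is shown below. The per-dataset class counts and the ignored gpu_ids argument are assumptions, not the project's actual implementation.

import torch

def make_one_hot(labels, dataset, gpu_ids=None):
    # Assumed class counts (21 matches the "having 21 channels" comments above
    # for VOC); treat these numbers as placeholders. gpu_ids is ignored here
    # because the output is placed on the same device as the labels.
    num_classes = {'voc2012': 21, 'cityscapes': 19, 'acdc': 4}[dataset]
    n, _, h, w = labels.size()
    one_hot = torch.zeros(n, num_classes, h, w, device=labels.device)
    # scatter_ assumes labels are already in [0, num_classes); an ignore
    # index (e.g. 255) would have to be masked out beforehand.
    return one_hot.scatter_(1, labels.long(), 1)
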
Example #3
    def train(self, args):
        # Transform for the test images (used for the per-epoch sample outputs below)
        transform_test = transforms.Compose([
            transforms.Resize((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_testdata_link(args.dataset_dir)

        a_test_data = dsets.ImageFolder(dataset_dirs['testA'],
                                        transform=transform_test)
        b_test_data = dsets.ImageFolder(dataset_dirs['testB'],
                                        transform=transform_test)

        a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=4)
        b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=4)

        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))
            # Save sample images for the current epoch
            #######################################################################
            a_real_test = Variable(next(iter(a_test_loader))[0])
            b_real_test = Variable(next(iter(b_test_loader))[0])
            a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

            self.Gab.eval()
            self.Gba.eval()

            with torch.no_grad():
                a_fake_test = self.Gab(b_real_test)
                b_fake_test = self.Gba(a_real_test)
                a_recon_test = self.Gab(b_fake_test)
                b_recon_test = self.Gba(a_fake_test)

            pic = (torch.cat([
                a_real_test, b_fake_test, a_recon_test, b_real_test,
                a_fake_test, b_recon_test
            ],
                             dim=0).data + 1) / 2.0

            if not os.path.isdir(args.results_dir):
                os.makedirs(args.results_dir)

            torchvision.utils.save_image(pic,
                                         args.results_dir +
                                         '/sample_{}.jpg'.format(epoch),
                                         nrow=3)

            self.Gab.train()
            self.Gba.train()
            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
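
utils.Sample_from_Pool appears in every example but is not shown. The usual CycleGAN image-history buffer (store up to 50 past fakes and, with probability 0.5, hand the discriminator an old fake instead of the current one) matches the call pattern sample([array])[0]; a sketch under that assumption:

import copy
import numpy as np

class Sample_from_Pool(object):
    # Image-history buffer in the spirit of the CycleGAN paper: keep up to
    # `max_elements` past generator outputs and, half of the time, return a
    # stored fake (replacing it with the current one) instead of the new fake.
    def __init__(self, max_elements=50):
        self.max_elements = max_elements
        self.items = []

    def __call__(self, in_items):
        return_items = []
        for item in in_items:
            if len(self.items) < self.max_elements:
                # Pool not full yet: store the fake and pass it through.
                self.items.append(item)
                return_items.append(item)
            elif np.random.rand() > 0.5:
                # Swap the current fake with a randomly chosen stored one.
                idx = np.random.randint(0, self.max_elements)
                return_items.append(copy.copy(self.items[idx]))
                self.items[idx] = item
            else:
                return_items.append(item)
        return return_items
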
Example #4
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            # transforms.RandomHorizontalFlip(),
            transforms.Resize((480, 1440)),
            # transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # PyTorch dataloader over a single merged dataset whose images apparently
        # contain the A and B domains side by side (dataset_dirs is unused here)
        dataset = torch.utils.data.DataLoader(dsets.ImageFolder(
            '/train_merged/', transform=transform),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)
        # a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
        #                                                 batch_size=args.batch_size, shuffle=True, num_workers=4)
        # b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
        #                                                 batch_size=args.batch_size, shuffle=True, num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, x in enumerate(dataset):
                # step
                step = epoch * len(dataset) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                x = Variable(x[0])
                x = utils.cuda([x])[0]
                shape = x.shape
                # Split the merged side-by-side image into its two halves
                # (left half assumed to be domain A, right half domain B).
                a_real = x[:, :, :, :shape[3] // 2]
                b_real = x[:, :, :, shape[3] // 2:]
                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, len(dataset), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
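
The g_lr_scheduler and d_lr_scheduler stepped at the end of each epoch are created outside these snippets. A common CycleGAN-style choice, shown here purely as an assumption, is a LambdaLR that holds the initial learning rate for a number of epochs and then decays it linearly to zero:

import itertools
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins for the two generators (the real code uses Gab / Gba).
Gab, Gba = nn.Conv2d(3, 3, 3), nn.Conv2d(3, 3, 3)
g_optimizer = optim.Adam(itertools.chain(Gab.parameters(), Gba.parameters()),
                         lr=2e-4, betas=(0.5, 0.999))

def lambda_rule(epoch, decay_epoch=100, total_epochs=200):
    # Constant lr for the first `decay_epoch` epochs, then linear decay to 0.
    return 1.0 - max(0.0, epoch - decay_epoch) / float(total_epochs - decay_epoch)

g_lr_scheduler = optim.lr_scheduler.LambdaLR(g_optimizer, lr_lambda=lambda_rule)
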
Example #5
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # transforms.Resize((args.load_height,args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)
        dataset_a = ListDataSet(
            '/media/l/新加卷1/city/data/river/train_256_9w.lst',
            transform=transform)
        dataset_b = ListDataSet('/media/l/新加卷/city/jinan_z3.lst',
                                transform=transform)
        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dataset_a,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dataset_b,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.4f | Dis Loss:%.4f" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict(),
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
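
utils.save_checkpoint presumably just serializes the state dictionary built above; a minimal version, together with the matching load used when resuming from self.start_epoch, might look like this:

import os
import torch

def save_checkpoint(state, save_path):
    # Write the training state (epoch, model and optimizer state dicts) to disk.
    ckpt_dir = os.path.dirname(save_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(state, save_path)

def load_checkpoint(ckpt_path, map_location='cpu'):
    # Counterpart for resuming: the caller restores the model/optimizer state
    # dicts and sets start_epoch = ckpt['epoch'].
    return torch.load(ckpt_path, map_location=map_location)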