Example #1
    def train(self,args):
        # For transforming the input image
        transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.Resize((args.img_height,args.img_width)),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(
            dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
            batch_size=args.batch_size, shuffle=True, num_workers=4)
        b_loader = torch.utils.data.DataLoader(
            dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
            batch_size=args.batch_size, shuffle=True, num_workers=4)

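        # History buffers of previously generated fakes (an image pool in the
        # CycleGAN sense, presumably a ~50-sample buffer) used when training the discriminators.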
        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):
            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # set train
                self.Gab.train()
                self.Gba.train()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

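                # Cycle reconstructions: Gab(Gba(a_real)) should recover a_real,
                # and Gba(Gab(b_real)) should recover b_real.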
                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Adversarial losses
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                a_cycle_loss = self.L1(a_recon, a_real)
                b_cycle_loss = self.L1(b_recon, b_real)

                # Total generators losses
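                # args.lamda weights the cycle-consistency terms (the CycleGAN paper uses lambda = 10).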
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss * args.lamda + b_cycle_loss * args.lamda

                # Update generators
                self.Gab.zero_grad()
                self.Gba.zero_grad()
                gen_loss.backward()
                self.gab_optimizer.step()
                self.gba_optimizer.step()

                # Sample from history of generated images
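                # Replacing fresh fakes with pooled ones keeps the discriminators
                # from overfitting to the generator's most recent outputs.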
                a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators 
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
                fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = a_dis_real_loss + a_dis_fake_loss
                b_dis_loss = b_dis_real_loss + b_dis_fake_loss

                # Update discriminators
                self.Da.zero_grad()
                self.Db.zero_grad()
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.da_optimizer.step()
                self.db_optimizer.step()

                print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" % 
                                            (epoch, i + 1, min(len(a_loader), len(b_loader)),
                                                            gen_loss,a_dis_loss+b_dis_loss))

            # Override the latest checkpoint 
            utils.save_checkpoint({'epoch': epoch + 1,
                                   'Da': self.Da.state_dict(),
                                   'Db': self.Db.state_dict(),
                                   'Gab': self.Gab.state_dict(),
                                   'Gba': self.Gba.state_dict(),
                                   'da_optimizer': self.da_optimizer.state_dict(),
                                   'db_optimizer': self.db_optimizer.state_dict(),
                                   'gab_optimizer': self.gab_optimizer.state_dict(),
                                   'gba_optimizer': self.gba_optimizer.state_dict()},
                                  '%s/latest.ckpt' % (args.checkpoint_dir))
Example #2
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

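        # The epoch runs for the length of the longer loader; the shorter one is
        # restarted when it runs out (see the try/except below).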
        max_len = max(len(a_loader), len(b_loader))

        steps = 0
        print_msg = 100  # print frequency in steps (assumed; the original left print_msg undefined)
        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            a_it = iter(a_loader)
            b_it = iter(b_loader)

            for i in range(max_len):
                try:
                    a_real = next(a_it)[0]
                except StopIteration:
                    a_it = iter(a_loader)
                    a_real = next(a_it)[0]

                try:
                    b_real = next(b_it)[0]
                except StopIteration:
                    b_it = iter(b_loader)
                    b_real = next(b_it)[0]

                # Generator Computations
                ##################################################

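                # Freeze the discriminators during the generator update so they
                # don't accumulate gradients from the generator loss.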
                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real)
                b_real = Variable(b_real)
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Identity losses
                ###################################################
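                # Passing a generator an image already from its output domain should
                # change nothing; this identity term preserves color and composition.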
                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # lamda = 1.75E+12
                lamda = args.lamda * args.idt_coef
                a_idt_loss = self.L1(a_idt, a_real) * lamda
                b_idt_loss = self.L1(b_idt, b_real) * lamda

                # a_real_features = vgg.get_features(a_real)
                # b_real_features = vgg.get_features(b_real)
                # a_fake_features = vgg.get_features(a_fake)
                # b_fake_features = vgg.get_features(b_fake)

                # Content losses
                # content_loss_weight = 1.50
                # content_loss_weight = 1
                # a_content_loss = vgg.get_content_loss(b_fake_features, a_real_features) * content_loss_weight
                # b_content_loss = vgg.get_content_loss(a_fake_features, b_real_features) * content_loss_weight

                # Style losses
                # style_loss_weight = 3.00E+05
                # style_loss_weight = 1

                # a_style_loss = vgg.get_style_loss(a_fake_features, a_real_features) * style_loss_weight
                # b_style_loss = vgg.get_style_loss(b_fake_features, b_real_features) * style_loss_weight

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                # gen_loss_weight = 4.50E+08
                gen_loss_weight = 1
                a_gen_loss = self.MSE(a_fake_dis, real_label) * gen_loss_weight
                b_gen_loss = self.MSE(b_fake_dis, real_label) * gen_loss_weight

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda
                # lamda = 3.50E+12
                # a_cycle_loss = self.L1(a_recon, a_real) * lamda
                # b_cycle_loss = self.L1(b_recon, b_real) * lamda

                # Total generators losses
                ###################################################
                # (an earlier variant also added the VGG style/content losses above)
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
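                # The 0.5 factor halves the discriminator's effective learning rate
                # relative to the generators, as in the original CycleGAN training scheme.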
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                steps += 1
                if steps % print_msg == 0:
                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, max(len(a_loader), len(b_loader)),
                           gen_loss, a_dis_loss + b_dis_loss))

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
Example #3
    def train(self, args):

        transform = get_transformation(
            (self.args.crop_height, self.args.crop_width),
            resize=True,
            dataset=args.dataset)
        val_transform = get_transformation((512, 512),
                                           resize=True,
                                           dataset=args.dataset)

        # let the choice of dataset configurable
        if self.args.dataset == 'voc2012':
            labeled_set = VOCDataset(root_path=root,
                                     name='label',
                                     ratio=1.0,
                                     transformation=transform,
                                     augmentation=None)
            val_set = VOCDataset(root_path=root,
                                 name='val',
                                 ratio=0.5,
                                 transformation=val_transform,
                                 augmentation=None)
            labeled_loader = DataLoader(labeled_set,
                                        batch_size=self.args.batch_size,
                                        shuffle=True,
                                        drop_last=True)
            val_loader = DataLoader(val_set,
                                    batch_size=self.args.batch_size,
                                    shuffle=True)
        elif self.args.dataset == 'cityscapes':
            labeled_set = CityscapesDataset(root_path=root_cityscapes,
                                            name='label',
                                            ratio=0.5,
                                            transformation=transform,
                                            augmentation=None)
            val_set = CityscapesDataset(root_path=root_cityscapes,
                                        name='val',
                                        ratio=0.5,
                                        transformation=transform,
                                        augmentation=None)
            labeled_loader = DataLoader(labeled_set,
                                        batch_size=self.args.batch_size,
                                        shuffle=True,
                                        drop_last=True)
            val_loader = DataLoader(val_set,
                                    batch_size=self.args.batch_size,
                                    shuffle=True,
                                    drop_last=True)
        elif self.args.dataset == 'acdc':
            labeled_set = ACDCDataset(root_path=root_acdc,
                                      name='label',
                                      ratio=0.5,
                                      transformation=transform,
                                      augmentation=None)
            val_set = ACDCDataset(root_path=root_acdc,
                                  name='val',
                                  ratio=0.5,
                                  transformation=transform,
                                  augmentation=None)
            labeled_loader = DataLoader(labeled_set,
                                        batch_size=self.args.batch_size,
                                        shuffle=True,
                                        drop_last=True)
            val_loader = DataLoader(val_set,
                                    batch_size=self.args.batch_size,
                                    shuffle=True,
                                    drop_last=True)

        img_fake_sample = utils.Sample_from_Pool()
        gt_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, self.args.epochs):
            self.Gsi.train()
            for i, (l_img, l_gt, img_name) in enumerate(labeled_loader):
                # step
                step = epoch * len(labeled_loader) + i + 1

                self.gsi_optimizer.zero_grad()

                l_img, l_gt = utils.cuda([l_img, l_gt], args.gpu_ids)

                lab_gt = self.Gsi(l_img)

                ### Upsample the model output to match the label size
                lab_gt = self.interp(lab_gt)

                # CE losses
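                # l_gt is (N, 1, H, W); squeeze(1) yields the (N, H, W) index map
                # that the cross-entropy loss expects.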
                fullsupervisedloss = self.CE(lab_gt, l_gt.squeeze(1))

                fullsupervisedloss.backward()
                self.gsi_optimizer.step()

                print("Epoch: (%3d) (%5d/%5d) | Crossentropy Loss:%.2e" %
                      (epoch, i + 1, len(labeled_loader),
                       fullsupervisedloss.item()))

                self.writer_supervised.add_scalars(
                    'Supervised Loss', {'CE Loss ': fullsupervisedloss},
                    len(labeled_loader) * epoch + i)

            ### For getting the IoU for the image
            self.Gsi.eval()
            with torch.no_grad():
                for i, (val_img, val_gt, _) in enumerate(val_loader):
                    val_img, val_gt = utils.cuda([val_img, val_gt],
                                                 args.gpu_ids)

                    outputs = self.Gsi(val_img)
                    outputs = self.interp_val(outputs)
                    outputs = self.activation_softmax(outputs)

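                    # Per-pixel prediction: argmax over the class (channel) dimension.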
                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = val_gt.squeeze().data.cpu().numpy()

                    self.running_metrics_val.update(gt, pred)

            score, class_iou = self.running_metrics_val.get_scores()

            self.running_metrics_val.reset()

            ### For displaying the images generated by generator on tensorboard
            val_img, val_gt, _ = next(iter(val_loader))
            val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids)
            with torch.no_grad():
                fake = self.Gsi(val_img).detach()
                fake = self.interp_val(fake)
            fake = self.activation_softmax(fake)
            fake_prediction = fake.data.max(1)[1].squeeze_(1).squeeze_(
                0).cpu().numpy()
            val_gt = val_gt.cpu()

            ### display_tensor is the final tensor that will be displayed on tensorboard
            display_tensor = torch.zeros(
                [fake.shape[0], 3, fake.shape[2], fake.shape[3]])
            display_tensor_gt = torch.zeros(
                [val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]])
            for i in range(fake_prediction.shape[0]):
                new_img = fake_prediction[i]
                new_img = utils.colorize_mask(
                    new_img, self.args.dataset
                )  ### So this is the generated image in PIL.Image format
                img_tensor = utils.PIL_to_tensor(new_img, self.args.dataset)
                display_tensor[i, :, :, :] = img_tensor

                display_tensor_gt[i, :, :, :] = val_gt[i]

            self.writer_supervised.add_image(
                'Generated segmented image',
                torchvision.utils.make_grid(display_tensor,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_supervised.add_image(
                'Ground truth for the image',
                torchvision.utils.make_grid(display_tensor_gt,
                                            nrow=2,
                                            normalize=True), epoch)

            if score["Mean IoU : \t"] >= self.best_iou:
                self.best_iou = score["Mean IoU : \t"]
                # Override the latest checkpoint
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'Gsi': self.Gsi.state_dict(),
                        'gsi_optimizer': self.gsi_optimizer.state_dict(),
                        'best_iou': self.best_iou,
                        'class_iou': class_iou
                    }, '%s/latest_supervised_model.ckpt' %
                    (self.args.checkpoint_dir))

        self.writer_supervised.close()
Example #4
    def train(self, args):
        transform = get_transformation((args.crop_height, args.crop_width),
                                       resize=True,
                                       dataset=args.dataset)

        # let the choice of dataset configurable
        if self.args.dataset == 'voc2012':
            labeled_set = VOCDataset(root_path=root,
                                     name='label',
                                     ratio=0.2,
                                     transformation=transform,
                                     augmentation=None)
            unlabeled_set = VOCDataset(root_path=root,
                                       name='unlabel',
                                       ratio=0.2,
                                       transformation=transform,
                                       augmentation=None)
            val_set = VOCDataset(root_path=root,
                                 name='val',
                                 ratio=0.5,
                                 transformation=transform,
                                 augmentation=None)
        elif self.args.dataset == 'cityscapes':
            labeled_set = CityscapesDataset(root_path=root_cityscapes,
                                            name='label',
                                            ratio=0.5,
                                            transformation=transform,
                                            augmentation=None)
            unlabeled_set = CityscapesDataset(root_path=root_cityscapes,
                                              name='unlabel',
                                              ratio=0.5,
                                              transformation=transform,
                                              augmentation=None)
            val_set = CityscapesDataset(root_path=root_cityscapes,
                                        name='val',
                                        ratio=0.5,
                                        transformation=transform,
                                        augmentation=None)
        elif self.args.dataset == 'acdc':
            labeled_set = ACDCDataset(root_path=root_acdc,
                                      name='label',
                                      ratio=0.5,
                                      transformation=transform,
                                      augmentation=None)
            unlabeled_set = ACDCDataset(root_path=root_acdc,
                                        name='unlabel',
                                        ratio=0.5,
                                        transformation=transform,
                                        augmentation=None)
            val_set = ACDCDataset(root_path=root_acdc,
                                  name='val',
                                  ratio=0.5,
                                  transformation=transform,
                                  augmentation=None)
        '''
        https://discuss.pytorch.org/t/about-the-relation-between-batch-size-and-length-of-data-loader/10510
        ^^ drop_last=True keeps every batch the same size by dropping the final,
        smaller batch when the dataset length is not divisible by the batch size.
        '''
        labeled_loader = DataLoader(labeled_set,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    drop_last=True)
        unlabeled_loader = DataLoader(unlabeled_set,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=True)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=True,
                                drop_last=True)

        new_img_fake_sample = utils.Sample_from_Pool()
        img_fake_sample = utils.Sample_from_Pool()
        gt_fake_sample = utils.Sample_from_Pool()

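        # Pre-initialized so the logging below still works on iterations where the
        # discriminator update is skipped.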
        img_dis_loss, gt_dis_loss, unsupervisedloss, fullsupervisedloss = 0, 0, 0, 0

        ### Variable to regulate the frequency of update between Discriminators and Generators
        counter = 0

        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            self.Gsi.train()
            self.Gis.train()

            # if (epoch+1)%10 == 0:
            # args.lamda_img = args.lamda_img + 0.08
            # args.lamda_gt = args.lamda_gt + 0.04

            for i, ((l_img, l_gt, _),
                    (unl_img, _,
                     _)) in enumerate(zip(labeled_loader, unlabeled_loader)):
                # step
                step = epoch * min(len(labeled_loader),
                                   len(unlabeled_loader)) + i + 1

                l_img, unl_img, l_gt = utils.cuda([l_img, unl_img, l_gt],
                                                  args.gpu_ids)

                # Generator Computations
                ##################################################

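                # Freeze all discriminators and the old_* networks (used only as
                # fixed references) while the generators are updated.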
                set_grad([self.Di, self.Ds, self.old_Di], False)
                set_grad([self.old_Gsi, self.old_Gis], False)
                self.g_optimizer.zero_grad()

                # Forward pass through generators
                ##################################################
                fake_img = self.Gis(
                    make_one_hot(l_gt, args.dataset, args.gpu_ids).float())
                fake_gt = self.Gsi(unl_img.float())  ### having 21 channels
                lab_gt = self.Gsi(l_img)  ### having 21 channels

                ### Resize the generator outputs to the label dimensions
                fake_img = self.interp(fake_img)
                fake_gt = self.interp(fake_gt)
                lab_gt = self.interp(lab_gt)

                # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)  ### will get into no channels
                # fake_gt = fake_gt.unsqueeze(1)   ### will get into 1 channel only
                # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)

                lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))

                ### Again applying activations
                lab_gt = self.activation_softmax(lab_gt)
                fake_gt = self.activation_softmax(fake_gt)
                # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
                # fake_gt = fake_gt.unsqueeze(1)
                # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)
                # fake_img = self.activation_tanh(fake_img)

                recon_img = self.Gis(fake_gt.float())
                recon_lab_img = self.Gis(lab_gt.float())
                recon_gt = self.Gsi(fake_img.float())

                ### Resize the reconstructed outputs to the label dimensions
                recon_img = self.interp(recon_img)
                recon_lab_img = self.interp(recon_lab_img)
                recon_gt = self.interp(recon_gt)

                ### This is for the case of the new loss between the recon_img from resnet and deeplab network
                resnet_fake_gt = self.old_Gsi(unl_img.float())
                resnet_lab_gt = self.old_Gsi(l_img)
                resnet_lab_gt = self.activation_softmax(resnet_lab_gt)
                resnet_fake_gt = self.activation_softmax(resnet_fake_gt)
                resnet_recon_img = self.old_Gis(resnet_fake_gt.float())
                resnet_recon_lab_img = self.old_Gis(resnet_lab_gt.float())

                ## Applying the tanh activations
                # recon_img = self.activation_tanh(recon_img)
                # recon_lab_img = self.activation_tanh(recon_lab_img)

                # Adversarial losses
                ###################################################
                fake_img_dis = self.Di(fake_img)
                resnet_fake_img_dis = self.old_Di(recon_img)

                ### For passing different type of input to Ds
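                # Harden the softmax output into a one-hot map so Ds receives the
                # same label encoding as for real ground truth.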
                fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(
                    1).squeeze_(0)
                fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
                fake_gt_discriminator = make_one_hot(fake_gt_discriminator,
                                                     args.dataset,
                                                     args.gpu_ids)
                fake_gt_dis = self.Ds(fake_gt_discriminator.float())
                # lab_gt_dis = self.Ds(lab_gt)

                real_label_gt = utils.cuda(
                    Variable(torch.ones(fake_gt_dis.size())), args.gpu_ids)
                real_label_img = utils.cuda(
                    Variable(torch.ones(fake_img_dis.size())), args.gpu_ids)

                # A cross-entropy loss would arguably suit this classification better.
                img_gen_loss = self.MSE(fake_img_dis, real_label_img)
                gt_gen_loss = self.MSE(fake_gt_dis, real_label_gt)
                # gt_label_gen_loss = self.MSE(lab_gt_dis, real_label)

                # Cycle consistency losses
                ###################################################
                resnet_img_cycle_loss = self.MSE(resnet_fake_img_dis,
                                                 real_label_img)
                # img_cycle_loss = self.L1(recon_img, unl_img)
                # img_cycle_loss_perceptual = perceptual_loss(recon_img, unl_img, args.gpu_ids)
                gt_cycle_loss = self.CE(recon_gt, l_gt.squeeze(1))
                # lab_img_cycle_loss = self.L1(recon_lab_img, l_img) * args.lamda

                # Total generators losses
                ###################################################
                # lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))
                lab_loss_MSE = self.L1(fake_img, l_img)
                # lab_loss_perceptual = perceptual_loss(fake_img, l_img, args.gpu_ids)

                fullsupervisedloss = args.lab_CE_weight * lab_loss_CE + args.lab_MSE_weight * lab_loss_MSE

                unsupervisedloss = args.adversarial_weight * (
                    img_gen_loss + gt_gen_loss
                ) + resnet_img_cycle_loss + gt_cycle_loss * args.lamda_gt

                gen_loss = fullsupervisedloss + unsupervisedloss

                # Update generators
                ###################################################
                gen_loss.backward()

                self.g_optimizer.step()

                if counter % 1 == 0:  # update D every step; raise the modulus to update less often
                    # Discriminator Computations
                    #################################################

                    set_grad([self.Di, self.Ds, self.old_Di], True)
                    self.d_optimizer.zero_grad()

                    # Sample from history of generated images
                    #################################################
                    # Gaussian-noise injection, currently disabled (probability 0.0)
                    if torch.rand(1) < 0.0:
                        fake_img = self.gauss_noise(fake_img.cpu())
                        fake_gt = self.gauss_noise(fake_gt.cpu())

                    recon_img = Variable(
                        torch.Tensor(
                            new_img_fake_sample([recon_img.cpu().data.numpy()
                                                 ])[0]))
                    fake_img = Variable(
                        torch.Tensor(
                            img_fake_sample([fake_img.cpu().data.numpy()])[0]))
                    # lab_gt = Variable(torch.Tensor(gt_fake_sample([lab_gt.cpu().data.numpy()])[0]))
                    fake_gt = Variable(
                        torch.Tensor(
                            gt_fake_sample([fake_gt.cpu().data.numpy()])[0]))

                    recon_img, fake_img, fake_gt = utils.cuda(
                        [recon_img, fake_img, fake_gt], args.gpu_ids)

                    # Forward pass through discriminators
                    #################################################
                    unl_img_dis = self.Di(unl_img)
                    fake_img_dis = self.Di(fake_img)
                    resnet_recon_img_dis = self.old_Di(resnet_recon_img)
                    resnet_fake_img_dis = self.old_Di(recon_img)

                    # lab_gt_dis = self.Ds(lab_gt)

                    l_gt = make_one_hot(l_gt, args.dataset, args.gpu_ids)
                    real_gt_dis = self.Ds(l_gt.float())

                    fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(
                        1).squeeze_(0)
                    fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
                    fake_gt_discriminator = make_one_hot(
                        fake_gt_discriminator, args.dataset, args.gpu_ids)
                    fake_gt_dis = self.Ds(fake_gt_discriminator.float())

                    real_label_img = utils.cuda(
                        Variable(torch.ones(unl_img_dis.size())), args.gpu_ids)
                    fake_label_img = utils.cuda(
                        Variable(torch.zeros(fake_img_dis.size())),
                        args.gpu_ids)
                    real_label_gt = utils.cuda(
                        Variable(torch.ones(real_gt_dis.size())), args.gpu_ids)
                    fake_label_gt = utils.cuda(
                        Variable(torch.zeros(fake_gt_dis.size())),
                        args.gpu_ids)

                    # Discriminator losses
                    ##################################################
                    img_dis_real_loss = self.MSE(unl_img_dis, real_label_img)
                    img_dis_fake_loss = self.MSE(fake_img_dis, fake_label_img)
                    gt_dis_real_loss = self.MSE(real_gt_dis, real_label_gt)
                    gt_dis_fake_loss = self.MSE(fake_gt_dis, fake_label_gt)
                    # lab_gt_dis_fake_loss = self.MSE(lab_gt_dis, fake_label)

                    cycle_img_dis_real_loss = self.MSE(resnet_recon_img_dis,
                                                       real_label_img)
                    cycle_img_dis_fake_loss = self.MSE(resnet_fake_img_dis,
                                                       fake_label_img)

                    # Total discriminators losses
                    img_dis_loss = (img_dis_real_loss +
                                    img_dis_fake_loss) * 0.5
                    gt_dis_loss = (gt_dis_real_loss + gt_dis_fake_loss) * 0.5
                    # lab_gt_dis_loss = (gt_dis_real_loss + lab_gt_dis_fake_loss)*0.33
                    cycle_img_dis_loss = cycle_img_dis_real_loss + cycle_img_dis_fake_loss

                    # Update discriminators
                    ##################################################
                    discriminator_loss = args.discriminator_weight * (
                        img_dis_loss + gt_dis_loss) + cycle_img_dis_loss
                    discriminator_loss.backward()

                    # lab_gt_dis_loss.backward()
                    self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Dis Loss:%.2e | Unlab Gen Loss:%.2e | Lab Gen loss:%.2e"
                    % (epoch, i + 1,
                       min(len(labeled_loader),
                           len(unlabeled_loader)), img_dis_loss + gt_dis_loss,
                       unsupervisedloss, fullsupervisedloss))

                self.writer_semisuper.add_scalars(
                    'Dis Loss', {
                        'img_dis_loss': img_dis_loss,
                        'gt_dis_loss': gt_dis_loss,
                        'cycle_img_dis_loss': cycle_img_dis_loss
                    },
                    len(labeled_loader) * epoch + i)
                self.writer_semisuper.add_scalars(
                    'Unlabelled Loss', {
                        'img_gen_loss': img_gen_loss,
                        'gt_gen_loss': gt_gen_loss,
                        'img_cycle_loss': resnet_img_cycle_loss,
                        'gt_cycle_loss': gt_cycle_loss
                    },
                    len(labeled_loader) * epoch + i)
                self.writer_semisuper.add_scalars(
                    'Labelled Loss', {
                        'lab_loss_CE': lab_loss_CE,
                        'lab_loss_MSE': lab_loss_MSE
                    },
                    len(labeled_loader) * epoch + i)

                counter += 1

            ### For getting the mean IoU
            self.Gsi.eval()
            self.Gis.eval()
            with torch.no_grad():
                for i, (val_img, val_gt, _) in enumerate(val_loader):
                    val_img, val_gt = utils.cuda([val_img, val_gt],
                                                 args.gpu_ids)

                    outputs = self.Gsi(val_img)
                    outputs = self.interp(outputs)
                    outputs = self.activation_softmax(outputs)

                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = val_gt.squeeze().data.cpu().numpy()

                    self.running_metrics_val.update(gt, pred)

            score, class_iou = self.running_metrics_val.get_scores()

            self.running_metrics_val.reset()

            print('The mIoU for the epoch is: ', score["Mean IoU : \t"])

            ### For displaying the images generated by generator on tensorboard using validation images
            val_image, val_gt, _ = next(iter(val_loader))
            val_image, val_gt = utils.cuda([val_image, val_gt], args.gpu_ids)
            with torch.no_grad():
                fake_label = self.Gsi(val_image).detach()
                fake_label = self.interp(fake_label)
                fake_label = self.activation_softmax(fake_label)
                fake_label = fake_label.data.max(1)[1].squeeze_(1).squeeze_(0)
                fake_label = fake_label.unsqueeze(1)
                fake_label = make_one_hot(fake_label, args.dataset,
                                          args.gpu_ids)
                fake_img = self.Gis(fake_label).detach()
                fake_img = self.interp(fake_img)
                # fake_img = self.activation_tanh(fake_img)

                fake_img_from_labels = self.Gis(
                    make_one_hot(val_gt, args.dataset,
                                 args.gpu_ids).float()).detach()
                fake_img_from_labels = self.interp(fake_img_from_labels)
                # fake_img_from_labels = self.activation_tanh(fake_img_from_labels)
                fake_label_regenerated = self.Gsi(
                    fake_img_from_labels).detach()
                fake_label_regenerated = self.interp(fake_label_regenerated)
                fake_label_regenerated = self.activation_softmax(
                    fake_label_regenerated)
            fake_prediction_label = fake_label.data.max(1)[1].squeeze_(
                1).cpu().numpy()
            fake_regenerated_label = fake_label_regenerated.data.max(
                1)[1].squeeze_(1).cpu().numpy()
            val_gt = val_gt.cpu()

            fake_img = fake_img.cpu()
            fake_img_from_labels = fake_img_from_labels.cpu()
            ### Undo the input normalization on these images
            if self.args.dataset == 'voc2012' or self.args.dataset == 'cityscapes':
                trans_mean = [0.5, 0.5, 0.5]
                trans_std = [0.5, 0.5, 0.5]
                for i in range(3):
                    fake_img[:, i, :, :] = (
                        (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
                    fake_img_from_labels[:, i, :, :] = (
                        (fake_img_from_labels[:, i, :, :] * trans_std[i]) +
                        trans_mean[i])

            elif self.args.dataset == 'acdc':
                trans_mean = [0.5]
                trans_std = [0.5]
                for i in range(1):
                    fake_img[:, i, :, :] = (
                        (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
                    fake_img_from_labels[:, i, :, :] = (
                        (fake_img_from_labels[:, i, :, :] * trans_std[i]) +
                        trans_mean[i])

            ### display_tensor is the final tensor that will be displayed on tensorboard
            display_tensor_label = torch.zeros([
                fake_label.shape[0], 3, fake_label.shape[2],
                fake_label.shape[3]
            ])
            display_tensor_gt = torch.zeros(
                [val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]])
            display_tensor_regen_label = torch.zeros([
                fake_label_regenerated.shape[0], 3,
                fake_label_regenerated.shape[2],
                fake_label_regenerated.shape[3]
            ])
            for i in range(fake_prediction_label.shape[0]):
                new_img_label = fake_prediction_label[i]
                new_img_label = utils.colorize_mask(
                    new_img_label, self.args.dataset
                )  ### So this is the generated image in PIL.Image format
                img_tensor_label = utils.PIL_to_tensor(new_img_label,
                                                       self.args.dataset)
                display_tensor_label[i, :, :, :] = img_tensor_label

                display_tensor_gt[i, :, :, :] = val_gt[i]

                regen_label = fake_regenerated_label[i]
                regen_label = utils.colorize_mask(regen_label,
                                                  self.args.dataset)
                regen_tensor_label = utils.PIL_to_tensor(
                    regen_label, self.args.dataset)
                display_tensor_regen_label[i, :, :, :] = regen_tensor_label

            self.writer_semisuper.add_image(
                'Generated segmented image: ',
                torchvision.utils.make_grid(display_tensor_label,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Generated image back from segmentation: ',
                torchvision.utils.make_grid(fake_img, nrow=2, normalize=True),
                epoch)
            self.writer_semisuper.add_image(
                'Ground truth for the image: ',
                torchvision.utils.make_grid(display_tensor_gt,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Image generated from val labels: ',
                torchvision.utils.make_grid(fake_img_from_labels,
                                            nrow=2,
                                            normalize=True), epoch)
            self.writer_semisuper.add_image(
                'Labels generated back from the cycle: ',
                torchvision.utils.make_grid(display_tensor_regen_label,
                                            nrow=2,
                                            normalize=True), epoch)

            if score["Mean IoU : \t"] >= self.best_iou:
                self.best_iou = score["Mean IoU : \t"]

                # Override the latest checkpoint
                #######################################################
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'Di': self.Di.state_dict(),
                        'Ds': self.Ds.state_dict(),
                        'Gis': self.Gis.state_dict(),
                        'Gsi': self.Gsi.state_dict(),
                        'd_optimizer': self.d_optimizer.state_dict(),
                        'g_optimizer': self.g_optimizer.state_dict(),
                        'best_iou': self.best_iou,
                        'class_iou': class_iou
                    }, '%s/latest_semisuper_cycleGAN.ckpt' %
                    (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

        self.writer_semisuper.close()
Example #5
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            # transforms.RandomHorizontalFlip(),
            transforms.Resize((480, 1440)),
            #  transforms.RandomCrop((args.crop_height,args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader. Note: this variant loads a merged A|B dataset from a
        # hard-coded path; dataset_dirs above is left unused.
        dataset = torch.utils.data.DataLoader(dsets.ImageFolder(
            '/train_merged/', transform=transform),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)
        # a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
        #                                                 batch_size=args.batch_size, shuffle=True, num_workers=4)
        # b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
        #                                                 batch_size=args.batch_size, shuffle=True, num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, x in enumerate(dataset):
                # step
                step = epoch * len(dataset) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                x = Variable(x[0])
                x = utils.cuda([x])[0]
                # Each merged sample holds domains A and B side by side, so split at
                # the width midpoint (left half = A, right half = B; assumed order).
                # The original sliced the same right half twice, making a_real == b_real.
                shape = x.shape
                a_real = x[:, :, :, :shape[3] // 2]
                b_real = x[:, :, :, shape[3] // 2:]
                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, len(dataset), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
Example #6
    def __init__(self, hyperparameters):
        super(Model, self).__init__()

        self.device = hyperparameters['device']
        self.auxiliary_data_source = hyperparameters['auxiliary_data_source']
        self.all_data_sources = ['resnet_features', self.auxiliary_data_source]
        self.DATASET = hyperparameters['dataset']
        self.num_shots = hyperparameters['num_shots']
        self.latent_size = hyperparameters['latent_size']
        self.batch_size = hyperparameters['batch_size']
        self.hidden_size_rule = hyperparameters['hidden_size_rule']
        self.warmup = hyperparameters['model_specifics']['warmup']
        self.generalized = hyperparameters['generalized']
        self.classifier_batch_size = 32
        self.img_seen_samples = hyperparameters['samples_per_class'][
            self.DATASET][0]
        self.att_seen_samples = hyperparameters['samples_per_class'][
            self.DATASET][1]
        self.att_unseen_samples = hyperparameters['samples_per_class'][
            self.DATASET][2]
        self.img_unseen_samples = hyperparameters['samples_per_class'][
            self.DATASET][3]
        self.reco_loss_function = hyperparameters['loss']
        self.nepoch = hyperparameters['epochs']
        self.lr_cls = hyperparameters['lr_cls']
        self.cross_reconstruction = hyperparameters['model_specifics'][
            'cross_reconstruction']
        self.cls_train_epochs = hyperparameters['cls_train_steps']
        self.dataset = dataloader(self.DATASET,
                                  copy.deepcopy(self.auxiliary_data_source),
                                  device=self.device)
        self.writer = SummaryWriter()
        self.num_gen_iter = hyperparameters['num_gen_iter']
        self.num_dis_iter = hyperparameters['num_dis_iter']
        self.pretrain = hyperparameters['pretrain']

        if self.DATASET == 'CUB':
            self.num_classes = 200
            self.num_novel_classes = 50
        elif self.DATASET == 'SUN':
            self.num_classes = 717
            self.num_novel_classes = 72
        elif self.DATASET == 'AWA1' or self.DATASET == 'AWA2':
            self.num_classes = 50
            self.num_novel_classes = 10

        feature_dimensions = [2048, self.dataset.aux_data.size(1)]
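        # 2048 is the dimensionality of the ResNet image features; the second
        # entry is the dimensionality of the auxiliary (e.g. attribute) data.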

        # Here, the encoders and decoders for all modalities are created and put into dicts

        self.encoder = {}

        for datatype, dim in zip(self.all_data_sources, feature_dimensions):

            self.encoder[datatype] = models.encoder_template(
                dim, self.latent_size, self.hidden_size_rule[datatype],
                self.device)

            print(str(datatype) + ' ' + str(dim))

        print('latent size ' + str(self.latent_size))

        self.decoder = {}
        for datatype, dim in zip(self.all_data_sources, feature_dimensions):
            self.decoder[datatype] = models.decoder_template(
                self.latent_size, dim, self.hidden_size_rule[datatype],
                self.device)

        # One optimizer covers all encoders and decoders. Because they are stored in
        # plain dicts rather than nn.ModuleDict, nn.Module does not register them,
        # so their parameters are collected explicitly.
        parameters_to_optimize = list(self.parameters())
        for datatype in self.all_data_sources:
            parameters_to_optimize += list(self.encoder[datatype].parameters())
            parameters_to_optimize += list(self.decoder[datatype].parameters())

        # The discriminator network is defined here
        self.net_D_Att = models.Discriminator(
            self.dataset.aux_data.size(1) + 2048, self.device)
        self.net_D_Img = models.Discriminator(
            2048 + self.dataset.aux_data.size(1), self.device)

        self.optimizer_G = optim.Adam(parameters_to_optimize,
                                      lr=hyperparameters['lr_gen_model'],
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=0.0005,
                                      amsgrad=True)
        self.optimizer_D = optim.Adam(itertools.chain(
            self.net_D_Att.parameters(), self.net_D_Img.parameters()),
                                      lr=hyperparameters['lr_gen_model'],
                                      betas=(0.5, 0.999),
                                      weight_decay=0.0005)
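        # Note the asymmetric Adam betas: G keeps the default beta1 = 0.9 while D
        # uses beta1 = 0.5, a common choice for stabilizing GAN discriminators.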

        if self.reco_loss_function == 'l2':
            self.reconstruction_criterion = nn.MSELoss(reduction='sum')

        elif self.reco_loss_function == 'l1':
            self.reconstruction_criterion = nn.L1Loss(reduction='sum')

        self.MSE = nn.MSELoss(reduction='sum')
        self.L1 = nn.L1Loss(reduction='sum')

        self.att_fake_from_att_sample = utils.Sample_from_Pool()
        self.att_fake_from_img_sample = utils.Sample_from_Pool()
        self.img_fake_from_img_sample = utils.Sample_from_Pool()
        self.img_fake_from_att_sample = utils.Sample_from_Pool()

        if self.generalized:
            print('mode: gzsl')
            self.clf = LINEAR_LOGSOFTMAX(self.latent_size, self.num_classes)
        else:
            print('mode: zsl')
            self.clf = LINEAR_LOGSOFTMAX(self.latent_size,
                                         self.num_novel_classes)
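
Note: every training loop in these examples draws the discriminator's "fake" inputs from a `utils.Sample_from_Pool` history buffer. The helper itself is not shown; the following is a minimal sketch consistent with how it is called here (a list of numpy arrays in, a list of the same length out), in the spirit of the CycleGAN image pool. The `max_elements` default is an assumption.

import numpy as np

class Sample_from_Pool(object):
    """Sketch of a history buffer for generated samples.

    Each incoming item is either passed through unchanged or swapped
    with a randomly chosen sample stored from earlier iterations.
    """

    def __init__(self, max_elements=50):
        self.max_elements = max_elements
        self.items = []

    def __call__(self, in_items):
        out_items = []
        for item in in_items:
            if len(self.items) < self.max_elements:
                # Fill the pool first, returning new items unchanged.
                self.items.append(item)
                out_items.append(item)
            elif np.random.rand() > 0.5:
                # Swap: return an old sample and store the new one.
                idx = np.random.randint(len(self.items))
                out_items.append(self.items[idx])
                self.items[idx] = item
            else:
                out_items.append(item)
        return out_items
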
Example #7
    def train(self, args):
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        a_loader = DataLoader(dsets.ImageFolder(dataset_dirs['trainA'],
                                                transform=transform),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4)
        b_loader = DataLoader(dsets.ImageFolder(dataset_dirs['trainB'],
                                                transform=transform),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = a_real[0]
                b_real = b_real[0]
                a_real, b_real = utils.cuda([a_real, b_real])

                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                a_idt_loss = self.L1(a_idt, a_real) * 5.0
                b_idt_loss = self.L1(b_idt, b_real) * 5.0

                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(torch.ones(a_fake_dis.size()))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                a_cycle_loss = self.L1(a_recon, a_real) * 10.0
                b_cycle_loss = self.L1(b_recon, b_real) * 10.0

                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                gen_loss.backward()
                self.g_optimizer.step()

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                a_fake = torch.Tensor(
                    a_fake_sample([a_fake.cpu().data.numpy()])[0])
                b_fake = torch.Tensor(
                    b_fake_sample([b_fake.cpu().data.numpy()])[0])
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(torch.ones(a_real_dis.size()))
                fake_label = utils.cuda(torch.zeros(a_fake_dis.size()))

                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
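
Example #7 and most of the later ones freeze the discriminators with `set_grad([...], False)` during the generator step and re-enable them before the discriminator step. A minimal sketch matching that usage (the function is assumed, not shown in these excerpts):

def set_grad(nets, requires_grad=False):
    # Toggle gradient tracking for every parameter of each network.
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad
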
Example #8
    def train(self, args):
        train_set = AudioTransformSet(
            args.dataset_dir + "Joni_Mitchell/files.txt",
            args.dataset_dir + "Nancy_Sinatra/files.txt",
            args.seq_len,
            sampling_rate=22050,
            augment=True)
        dataloader = DataLoader(train_set,
                                batch_size=args.batch_size,
                                num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, data in enumerate(dataloader):
                # step

                step = epoch * len(dataloader) + i + 1
                print(step)

                a_real = data[0]
                b_real = data[1]

                a_real = a_real.cuda()
                b_real = b_real.cuda()

                a_r_spec = self.fft(a_real).detach()
                b_r_spec = self.fft(b_real).detach()

                print("Shape of a-spectrogram: {}".format(a_r_spec.size()))
                print("Shape of b-spectrogram: {}".format(b_r_spec.size()))
                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                # Forward pass through generators
                ##################################################
                a_fake = self.g_AB(b_r_spec.cuda())
                b_fake = self.g_BA(a_r_spec.cuda())

                a_f_spec = self.fft(a_fake).detach()
                b_f_spec = self.fft(b_fake).detach()

                print("Shape of a-fake spectrogram: {}".format(
                    a_f_spec.size()))
                print("Shape of b-fake spectrogram: {}".format(
                    b_f_spec.size()))

                a_recon = self.g_AB(b_f_spec)
                b_recon = self.g_BA(a_f_spec)

                a_recon_spec = self.fft(a_recon).detach()
                b_recon_spec = self.fft(b_recon).detach()

                a_idt = self.g_AB(a_r_spec.cuda())
                b_idt = self.g_BA(b_r_spec.cuda())

                a_idt = self.fft(a_idt).detach()
                b_idt = self.fft(b_idt).detach()

                print("Shape of a_recon spectrogram: {}".format(
                    a_recon.size()))
                print("Shape of b_recon spectrogram: {}".format(
                    b_recon.size()))

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_r_spec) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_r_spec) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                print(a_fake_dis[2][6].size())
                real_label = utils.cuda(
                    Variable(torch.ones(a_fake_dis[2][6].size())))

                a_gen_loss = self.MSE(a_fake_dis[2][6], real_label)
                b_gen_loss = self.MSE(b_fake_dis[2][6], real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon_spec, a_r_spec) * args.lamda
                b_cycle_loss = self.L1(b_recon_spec, b_r_spec) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward(retain_graph=True)
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_f_spec = Variable(
                    torch.Tensor(
                        a_fake_sample([a_f_spec.cpu().data.numpy()])[0]))
                b_f_spec = Variable(
                    torch.Tensor(
                        b_fake_sample([b_f_spec.cpu().data.numpy()])[0]))
                a_f_spec, b_f_spec = utils.cuda([a_f_spec, b_f_spec])

                print("Shape of a-fake spectrogram: {}".format(
                    a_f_spec.size()))
                print("Shape of b-fake spectrogram: {}".format(
                    b_f_spec.size()))

                print("Shape of a-spectrogram: {}".format(a_r_spec.size()))
                print("Shape of b-spectrogram: {}".format(b_r_spec.size()))

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(
                    Variable(torch.ones(a_real_dis[2][6].size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis[2][6].size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis[2][6], real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis[2][6], fake_label)
                b_dis_real_loss = self.MSE(b_real_dis[2][6], real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis[2][6], fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward(retain_graph=True)
                b_dis_loss.backward(retain_graph=True)
                self.d_optimizer.step()

                # Log the running losses (scaled by 1/1000) at every
                # mini-batch; assumes a module-level SummaryWriter `writer`
                # defined elsewhere.
                writer.add_scalar('DisA loss', a_dis_loss / 1000,
                                  epoch * len(dataloader) + i)
                writer.add_scalar('DisB loss', b_dis_loss / 1000,
                                  epoch * len(dataloader) + i)

                writer.add_scalar('Generator loss', gen_loss / 1000,
                                  epoch * len(dataloader) + i)

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, len(dataloader), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Overwrite the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.g_AB.state_dict(),
                    'Gba': self.g_BA.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
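
`utils.cuda` is used throughout to move a tensor, or a list of tensors, to the GPU when one is available. A plausible sketch under that assumption:

import torch

def cuda(xs):
    """Move a tensor or a list of tensors to the GPU, if available."""
    if not torch.cuda.is_available():
        return xs
    if isinstance(xs, (list, tuple)):
        return [x.cuda() for x in xs]
    return xs.cuda()
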
Example #9
    def train(self, args):
        # Image transforms
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Initialize dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        # live plot loss
        Gab_history = hl.History()
        Gba_history = hl.History()
        gan_history = hl.History()
        Da_history = hl.History()
        Db_history = hl.History()

        canvas = hl.Canvas()

        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):

                # Identify step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generators ===============================================================
                # Turning off grads for discriminators
                set_grad([self.Da, self.Db], False)

                # Zero out grads of the generator
                self.g_optimizer.zero_grad()

                # Real images from sets A and B
                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Passing through generators
                # Nomenclature: a_fake is the fake image generated from b_real
                # in domain A. NOTE: Gab generates A from B, and vice versa.
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Each generator should reproduce an image from its own domain
                # when given an input from that domain
                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity loss
                a_idt_loss = self.L1(a_idt, a_real) * args.delta
                b_idt_loss = self.L1(b_idt, b_real) * args.delta

                # Adversarial loss
                # Da returns 1 for an image in domain A
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                # Label expected here is 1 to fool the discriminator
                expected_label_a = utils.cuda(
                    Variable(torch.ones(a_fake_dis.size())))
                expected_label_b = utils.cuda(
                    Variable(torch.ones(b_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, expected_label_a)
                b_gen_loss = self.MSE(b_fake_dis, expected_label_b)

                # Cycle Consistency loss
                a_cycle_loss = self.L1(a_recon, a_real) * args.alpha
                b_cycle_loss = self.L1(b_recon, b_real) * args.alpha

                # Structural Cycle Consistency loss
                a_scyc_loss = self.ssim(a_recon, a_real) * args.beta
                b_scyc_loss = self.ssim(b_recon, b_real) * args.beta

                # Structural similarity loss
                # 'ba' refers to the SSIM scores between the input and the
                # output generated by gen_ba; grayscale values lie in [0, 1]
                gray = kornia.color.RgbToGrayscale()
                a_real_gray = gray((a_real + 1) / 2.0)
                a_fake_gray = gray((a_fake + 1) / 2.0)
                a_recon_gray = gray((a_recon + 1) / 2.0)
                b_real_gray = gray((b_real + 1) / 2.0)
                b_fake_gray = gray((b_fake + 1) / 2.0)
                b_recon_gray = gray((b_recon + 1) / 2.0)

                ba_ssim_loss = (
                    (self.ssim(a_real_gray, b_fake_gray)) +
                    (self.ssim(a_fake_gray, b_recon_gray))) * args.gamma
                ab_ssim_loss = (
                    (self.ssim(b_real_gray, a_fake_gray)) +
                    (self.ssim(b_fake_gray, a_recon_gray))) * args.gamma

                # Total Generator Loss
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_scyc_loss + b_scyc_loss + a_idt_loss + b_idt_loss + ba_ssim_loss + ab_ssim_loss

                # Update Generators
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminators ===========================================================
                # Turn on grads for discriminators
                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from previously generated fake images
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Pass through discriminators
                # Discriminator for domain A
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)

                # Discriminator for domain B
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)

                # Expected label for real image is 1
                exp_real_label_a = utils.cuda(
                    Variable(torch.ones(a_real_dis.size())))
                exp_fake_label_a = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                exp_real_label_b = utils.cuda(
                    Variable(torch.ones(b_real_dis.size())))
                exp_fake_label_b = utils.cuda(
                    Variable(torch.zeros(b_fake_dis.size())))

                # Discriminator losses
                a_real_dis_loss = self.MSE(a_real_dis, exp_real_label_a)
                a_fake_dis_loss = self.MSE(a_fake_dis, exp_fake_label_a)
                b_real_dis_loss = self.MSE(b_real_dis, exp_real_label_b)
                b_fake_dis_loss = self.MSE(b_fake_dis, exp_fake_label_b)

                # Total discriminator loss
                a_dis_loss = (a_fake_dis_loss + a_real_dis_loss) / 2
                b_dis_loss = (b_fake_dis_loss + b_real_dis_loss) / 2

                # Update discriminators
                a_dis_loss.backward()
                b_dis_loss.backward()

                self.d_optimizer.step()

                if i % args.log_freq == 0:
                    # Log losses
                    Gab_history.log(step,
                                    gen_loss=a_gen_loss,
                                    cycle_loss=a_cycle_loss,
                                    idt_loss=a_idt_loss,
                                    ssim_loss=ab_ssim_loss,
                                    scyc_loss=a_scyc_loss)

                    Gba_history.log(step,
                                    gen_loss=b_gen_loss,
                                    cycle_loss=b_cycle_loss,
                                    idt_loss=b_idt_loss,
                                    ssim_loss=ba_ssim_loss,
                                    scyc_loss=b_scyc_loss)

                    Da_history.log(step,
                                   loss=a_dis_loss,
                                   fake_loss=a_fake_dis_loss,
                                   real_loss=a_real_dis_loss)

                    Db_history.log(step,
                                   loss=b_dis_loss,
                                   fake_loss=b_fake_dis_loss,
                                   real_loss=b_real_dis_loss)

                    gan_history.log(step,
                                    gen_loss=gen_loss,
                                    dis_loss=(a_dis_loss + b_dis_loss))

                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, min(len(a_loader), len(b_loader)),
                           gen_loss, a_dis_loss + b_dis_loss))
                    with canvas:
                        canvas.draw_plot([
                            Gba_history['gen_loss'], Gba_history['cycle_loss'],
                            Gba_history['idt_loss'], Gba_history['ssim_loss'],
                            Gba_history['scyc_loss']
                        ],
                                         labels=[
                                             'Adv loss', 'Cycle loss',
                                             'Identity loss', 'SSIM',
                                             'SCyC loss'
                                         ])

                        canvas.draw_plot([
                            Gab_history['gen_loss'], Gab_history['cycle_loss'],
                            Gab_history['idt_loss'], Gab_history['ssim_loss'],
                            Gab_history['scyc_loss']
                        ],
                                         labels=[
                                             'Adv loss', 'Cycle loss',
                                             'Identity loss', 'SSIM',
                                             'SCyC loss'
                                         ])

                        canvas.draw_plot(
                            [
                                Db_history['loss'], Db_history['fake_loss'],
                                Db_history['real_loss']
                            ],
                            labels=['Loss', 'Fake Loss', 'Real Loss'])

                        canvas.draw_plot(
                            [
                                Da_history['loss'], Da_history['fake_loss'],
                                Da_history['real_loss']
                            ],
                            labels=['Loss', 'Fake Loss', 'Real Loss'])

                        canvas.draw_plot(
                            [gan_history['gen_loss'], gan_history['dis_loss']],
                            labels=['Generator loss', 'Discriminator loss'])

            # Overwrite checkpoint
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_path))

            # Save loss history
            history_path = args.results_path + '/loss_history/'
            utils.mkdir([history_path])
            Gab_history.save(history_path + "Gab.pkl")
            Gba_history.save(history_path + "Gba.pkl")
            Da_history.save(history_path + "Da.pkl")
            Db_history.save(history_path + "Db.pkl")
            gan_history.save(history_path + "gan.pkl")

            # Update learning rates
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

            # Run one test cycle
            if args.testing:
                print('Testing')
                tst.test(args, epoch)
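
Example #9's `self.ssim` is defined outside this excerpt. For the structural cycle and SSIM terms above to be minimized meaningfully, it must be a dissimilarity (small when the images match), e.g. a (1 - SSIM)-style loss. One hypothetical definition using kornia, which the example already imports:

import kornia

# Hypothetical criterion: kornia's SSIMLoss is a (1 - SSIM)-style
# dissimilarity, so minimizing it pushes the two images toward
# structural similarity.
ssim_criterion = kornia.losses.SSIMLoss(window_size=11, reduction='mean')
# self.ssim = ssim_criterion  # as assumed by the loss terms in Example #9
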
Example #10
    def train(self, args):
        # Test input
        transform_test = transforms.Compose([
            transforms.Resize((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_testdata_link(args.dataset_dir)

        a_test_data = dsets.ImageFolder(dataset_dirs['testA'],
                                        transform=transform_test)
        b_test_data = dsets.ImageFolder(dataset_dirs['testB'],
                                        transform=transform_test)

        a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=4)
        b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=4)

        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Overwrite the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))
            # Save sample images for the current epoch
            #######################################################################
            a_real_test = Variable(next(iter(a_test_loader))[0],
                                   requires_grad=True)
            b_real_test = Variable(next(iter(b_test_loader))[0],
                                   requires_grad=True)
            a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

            self.Gab.eval()
            self.Gba.eval()

            with torch.no_grad():
                a_fake_test = self.Gab(b_real_test)
                b_fake_test = self.Gba(a_real_test)
                a_recon_test = self.Gab(b_fake_test)
                b_recon_test = self.Gba(a_fake_test)

            pic = (torch.cat([
                a_real_test, b_fake_test, a_recon_test, b_real_test,
                a_fake_test, b_recon_test
            ],
                             dim=0).data + 1) / 2.0

            if not os.path.isdir(args.results_dir):
                os.makedirs(args.results_dir)

            torchvision.utils.save_image(pic,
                                         args.results_dir +
                                         '/sample_{}.jpg'.format(epoch),
                                         nrow=3)

            self.Gab.train()
            self.Gba.train()
            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
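
All examples persist their state through `utils.save_checkpoint`, overwriting a single `latest.ckpt` per run. A minimal sketch consistent with that call pattern (the directory handling is an assumption):

import os
import torch

def save_checkpoint(state, save_path):
    """Write the training-state dict to disk, overwriting any old file."""
    dir_name = os.path.dirname(save_path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    torch.save(state, save_path)
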
Example #11
    def train(self, args):
        # data transformation
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Dataloader for class A and B
        a_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        # get fake samples from the sample pool
        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                '''
                Generator First, Discriminator Second
                '''
                # Generator Optimization
                set_grad([self.D_A, self.D_B], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                a_fake = self.G_BtoA(b_real)
                b_fake = self.G_AtoB(a_real)

                a_recon = self.G_BtoA(b_fake)
                b_recon = self.G_AtoB(a_fake)

                a_idt = self.G_BtoA(a_real)
                b_idt = self.G_AtoB(b_real)

                # Identity losses
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                a_fake_dis = self.D_A(a_fake)
                b_fake_dis = self.D_B(b_fake)

                a_real_label = utils.cuda(
                    Variable(torch.ones(a_fake_dis.size())))
                b_real_label = utils.cuda(
                    Variable(torch.ones(b_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, a_real_label)
                b_gen_loss = self.MSE(b_fake_dis, b_real_label)

                # Cycle consistency losses
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Optimization
                set_grad([self.D_A, self.D_B], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                a_real_dis = self.D_A(a_real)
                a_fake_dis = self.D_A(a_fake)
                b_real_dis = self.D_B(b_real)
                b_fake_dis = self.D_B(b_fake)

                a_real_label = utils.cuda(
                    Variable(torch.ones(a_real_dis.size())))
                a_fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))
                b_real_label = utils.cuda(
                    Variable(torch.ones(b_real_dis.size())))
                b_fake_label = utils.cuda(
                    Variable(torch.zeros(b_fake_dis.size())))

                # Discriminator losses
                a_dis_real_loss = self.MSE(a_real_dis, a_real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, a_fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, b_real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, b_fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                # print some information
                if (i + 1) % 20 == 0:
                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, min(len(a_loader), len(b_loader)),
                           gen_loss, a_dis_loss + b_dis_loss))

            # Update the checkpoint
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'D_A': self.D_A.state_dict(),
                    'D_B': self.D_B.state_dict(),
                    'G_AtoB': self.G_AtoB.state_dict(),
                    'G_BtoA': self.G_BtoA.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
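
A pattern shared by every loop here: the adversarial terms implement the least-squares GAN objective, scoring discriminator outputs with an MSE loss against all-ones (real) and all-zeros (fake) targets and halving the sum. A self-contained illustration of the discriminator-side computation (shapes are hypothetical, PatchGAN-style):

import torch
import torch.nn as nn

mse = nn.MSELoss()

# Hypothetical patch-level discriminator outputs.
real_scores = torch.rand(4, 1, 30, 30)
fake_scores = torch.rand(4, 1, 30, 30)

# LSGAN: push real patches toward 1 and fake patches toward 0;
# the 0.5 factor weights the two terms equally, as in the examples above.
d_loss = 0.5 * (mse(real_scores, torch.ones_like(real_scores)) +
                mse(fake_scores, torch.zeros_like(fake_scores)))
print(d_loss.item())
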
Example #12
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # transforms.Resize((args.load_height,args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)
        dataset_a = ListDataSet(
            '/media/l/新加卷1/city/data/river/train_256_9w.lst',
            transform=transform)
        dataset_b = ListDataSet('/media/l/新加卷/city/jinan_z3.lst',
                                transform=transform)
        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dataset_a,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dataset_b,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.L1(a_idt,
                                     a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt,
                                     b_real) * args.lamda * args.idt_coef

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.4f | Dis Loss:%.4f" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))

            # Overwrite the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict(),
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
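
`utils.get_traindata_link` only needs to return a dict with 'trainA' and 'trainB' directories (and `get_testdata_link` the test equivalents). A sketch assuming the standard CycleGAN dataset layout, which these excerpts do not confirm:

import os

def get_traindata_link(dataset_dir):
    # Assumes the usual <dataset_dir>/trainA, <dataset_dir>/trainB layout.
    return {'trainA': os.path.join(dataset_dir, 'trainA'),
            'trainB': os.path.join(dataset_dir, 'trainB')}

def get_testdata_link(dataset_dir):
    return {'testA': os.path.join(dataset_dir, 'testA'),
            'testB': os.path.join(dataset_dir, 'testB')}
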
Example #13
    def train(self, args):
        # For transforming the input image
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        test_transform = transforms.Compose([
            transforms.Resize((args.test_crop_height, args.test_crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)
        testset_dirs = utils.get_testdata_link(args.dataset_dir)

        # Pytorch dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_test_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            testset_dirs['testA'], transform=test_transform),
                                                    batch_size=1,
                                                    shuffle=False,
                                                    num_workers=4)
        b_test_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            testset_dirs['testB'], transform=test_transform),
                                                    batch_size=1,
                                                    shuffle=False,
                                                    num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        for epoch in range(self.start_epoch, args.epochs):

            if epoch >= 1:
                print('generating test result...')
                self.save_sample_image(args.test_length, a_test_loader,
                                       b_test_loader, args.results_dir, epoch)

            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)
            running_Gen_loss = 0
            running_Dis_loss = 0
            ##################################################
            # BEGIN TRAINING FOR ONE EPOCH
            ##################################################
            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
                # step
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                ##################################################
                # Part 1: Generator Computations
                ##################################################

                set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Forward pass through generators
                ##################################################
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity losses
                ###################################################
                a_idt_loss = self.identity_criteron(
                    a_idt, a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.identity_criteron(
                    b_idt, b_real) * args.lamda * args.idt_coef
                # a_idt_loss = 0
                # b_idt_loss = 0

                # Adversarial losses
                ###################################################
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                real_label = utils.cuda(Variable(torch.ones(
                    a_fake_dis.size())))

                a_gen_loss = self.adversarial_criteron(a_fake_dis, real_label)
                b_gen_loss = self.adversarial_criteron(b_fake_dis, real_label)

                # Cycle consistency losses
                ###################################################
                a_cycle_loss = self.cycle_consistency_criteron(
                    a_recon, a_real) * args.lamda
                b_cycle_loss = self.cycle_consistency_criteron(
                    b_recon, b_real) * args.lamda

                # Total generators losses
                ###################################################
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Update generators
                ###################################################
                gen_loss.backward()
                self.g_optimizer.step()

                ##################################################
                # Part 2: Discriminator Computations
                #################################################

                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Forward pass through discriminators
                #################################################
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = utils.cuda(Variable(torch.ones(
                    a_real_dis.size())))
                fake_label = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # Discriminator losses
                ##################################################
                a_dis_real_loss = self.adversarial_criteron(
                    a_real_dis, real_label)
                a_dis_fake_loss = self.adversarial_criteron(
                    a_fake_dis, fake_label)
                b_dis_real_loss = self.adversarial_criteron(
                    b_real_dis, real_label)
                b_dis_fake_loss = self.adversarial_criteron(
                    b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Update discriminators
                ##################################################
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print(
                    "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                    (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss,
                     a_dis_loss + b_dis_loss))
                # Accumulate detached scalars so autograd graphs are freed
                running_Gen_loss += gen_loss.item()
                running_Dis_loss += (a_dis_loss + b_dis_loss).item()
            ##################################################
            # END TRAINING FOR ONE EPOCH
            ##################################################
            self.writer.add_scalar(
                'Gen Loss',
                running_Gen_loss / min(len(a_loader), len(b_loader)), epoch)
            self.writer.add_scalar(
                'Dis Loss',
                running_Dis_loss / min(len(a_loader), len(b_loader)), epoch)
            self.writer.add_scalar('Gen_LR',
                                   self.g_lr_scheduler.get_lr()[0], epoch)
            self.writer.add_scalar('Dis_LR',
                                   self.d_lr_scheduler.get_lr()[0], epoch)
            # Overwrite the latest checkpoint
            #######################################################
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_dir))

            # Update learning rates
            ########################
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

        self.writer.close()
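
Each epoch ends with `g_lr_scheduler.step()` / `d_lr_scheduler.step()`; the schedulers are created outside these excerpts. A hypothetical setup in the usual CycleGAN style (constant learning rate, then linear decay to zero) might look like:

import torch.optim as optim

def make_linear_decay(total_epochs=200, decay_start=100):
    # Factor 1.0 up to decay_start, then linearly down to 0 at total_epochs.
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch - decay_start) / float(total_epochs - decay_start)
    return lambda_rule

# Hypothetical wiring, mirroring the attributes used in the examples:
# self.g_lr_scheduler = optim.lr_scheduler.LambdaLR(self.g_optimizer,
#                                                   lr_lambda=make_linear_decay())
# self.d_lr_scheduler = optim.lr_scheduler.LambdaLR(self.d_optimizer,
#                                                   lr_lambda=make_linear_decay())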