# Example 1
class Network(nn.Module):
    """Selector/discriminator pair on top of a frozen ResNet-50 feature extractor.

    The feature extractor is used purely as a fixed embedding function (its
    outputs are detached in ``forward``); only the selector and the
    discriminator have their own Adam optimizers.
    """

    def __init__(self):
        super(Network, self).__init__()
        self.feature_extractor = resnet50()  # Already pretrained
        self.selector = Selector()
        self.dis = Discriminator()
        self.optmzr_select = Adam(self.selector.parameters(), lr=1e-3)
        self.optmzr_dis = Adam(self.dis.parameters(), lr=1e-3)

    def forward(self, anchor: Variable, real_data: Variable, fake_data: Variable):
        """Score (anchor, real) and (anchor, fake) pairs with the discriminator.

        All three inputs are expected to be 4-D image batches (N, C, H, W).
        Returns ``(score_real, score_fake)``.
        """
        # Bug fix: the original checked the anchor's rank twice and never
        # validated real_data / fake_data.
        assert (len(anchor.size()) == 4
                and len(real_data.size()) == 4
                and len(fake_data.size()) == 4)

        fea_anchor = self.feature_extractor(anchor)
        fea_real = self.feature_extractor(real_data)
        fea_fake = self.feature_extractor(fake_data)

        # The extractor is not trained: cut the graph so no gradients flow
        # back into it.
        fea_anchor = fea_anchor.detach()
        fea_real = fea_real.detach()
        fea_fake = fea_fake.detach()

        score_real = self.dis(fea_anchor, fea_real)
        score_fake = self.dis(fea_anchor, fea_fake)

        return score_real, score_fake

    def bp_dis(self, score_real, score_fake):
        """Backprop the discriminator on one batch using noisy BCE targets.

        Real targets are drawn around 1 and fake targets around 0 with a
        std of 0.05 (label smoothing/noise).
        """
        real_label = Variable(
            torch.normal(torch.ones(score_real.size()),
                         torch.zeros(score_real.size()) + 0.05)).cuda()
        fake_label = Variable(
            torch.normal(torch.zeros(score_real.size()),
                         torch.zeros(score_real.size()) + 0.05)).cuda()
        # size_average=False keeps the *summed* BCE (matching the original
        # training dynamics); torch.mean then collapses it to a scalar.
        loss = torch.mean(F.binary_cross_entropy(score_real, real_label, size_average=False) +
                          F.binary_cross_entropy(score_fake, fake_label, size_average=False))

        self.optmzr_dis.zero_grad()
        loss.backward()
        return self.optmzr_dis.step()

    def bp_select(self, score_fake: Variable, fake_prob):
        """REINFORCE-style gradient for the selector.

        The detached discriminator score, rescaled from [0, 1] to [-1, 1],
        is fed as the per-sample upstream gradient of ``log(fake_prob)``.
        """
        n_sample = score_fake.size()[0]
        # Bug fix: the original zeroed optmzr_dis here, so selector
        # gradients silently accumulated across calls.
        self.optmzr_select.zero_grad()
        reward = (score_fake.data - .5) * 2
        torch.log(fake_prob).backward(reward / n_sample)
        # NOTE(review): no optimizer step is taken here — presumably
        # optmzr_select.step() is invoked by the caller; confirm.
# Example 2
class CoCosModel(BaseModel):
    """Exemplar-based translation model (CoCosNet-style).

    A correspondence network (netC) warps an exemplar image onto the input's
    layout, a translation network (netT) renders the final image from the
    warp, and a (multiscale) discriminator (netD) supplies the adversarial
    training signal.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """No model-specific CLI options; return the parser unchanged."""
        return parser

    @staticmethod
    def torch2numpy(x):
        """Convert a CHW float tensor in [-1, 1] to an HWC uint8 array in [0, 255]."""
        return ((x.detach().cpu().numpy().transpose(1, 2, 0) + 1) *
                127.5).astype(np.uint8)

    def __name__(self):
        return 'CoCosModel'

    def __init__(self, opt):
        super().__init__(opt)
        self.w = opt.image_size
        # Folder for saved visualization images
        self.image_dir = os.path.join(self.save_dir, 'images')
        if not os.path.isdir(self.image_dir):
            os.mkdir(self.image_dir)

        # initialize networks
        self.model_names = ['C', 'T']
        self.netC = CorrespondenceNet(opt)
        self.netT = TranslationNet(opt)
        if opt.isTrain:
            self.model_names.append('D')
            self.netD = Discriminator(opt)

        self.visual_names = ['b_exemplar', 'a', 'b_gen',
                             'b_gt']  # HPT convention

        if opt.isTrain:
            # assign losses
            self.loss_names = [
                'perc', 'domain', 'feat', 'context', 'reg', 'adv'
            ]
            self.visual_names += ['b_warp']
            self.criterionFeat = torch.nn.L1Loss()
            # Single interface for both VGG feature and perceptual losses;
            # called with different mode/layer params.
            self.criterionVGG = VGGLoss(self.device)
            # Supports hinge loss via opt.gan_mode
            self.criterionAdv = GANLoss(gan_mode=opt.gan_mode).to(self.device)
            self.criterionDomain = nn.L1Loss()
            self.criterionReg = torch.nn.L1Loss()

            # initialize optimizers (C and T share one optimizer)
            gen_params = itertools.chain(self.netT.parameters(),
                                         self.netC.parameters())
            self.optG = torch.optim.Adam(gen_params,
                                         lr=opt.lr,
                                         betas=(opt.beta1, 0.999))
            self.optD = torch.optim.Adam(self.netD.parameters(),
                                         lr=opt.lr,
                                         betas=(opt.beta1, 0.999))
            self.optimizers = [self.optG, self.optD]

        # Finally, load checkpoints and recover schedulers
        self.setup(opt)
        # NOTE(review): anomaly detection slows training considerably;
        # keep only while debugging autograd issues.
        torch.autograd.set_detect_anomaly(True)

    def set_input(self, batch):
        # expecting 'a' -> 'b_gt', 'a_exemplar' -> 'b_exemplar', ('b_deform')
        # for human pose transfer, 'b_deform' is already 'b_exemplar'
        for k, v in batch.items():
            setattr(self, k, v.to(self.device))

    def forward(self):
        """Warp the exemplar onto `a`, render, and compute the reg warp."""
        self.sa, self.sb, self.fb_warp, self.b_warp = self.netC(
            self.a, self.b_exemplar)  # 3*HW*HW
        self.b_gen = self.netT(self.b_warp)

        # Warp back for the regularization term.
        # TODO: Implement backward warping (maybe adjust the input size?)
        _, _, _, self.b_reg = self.netC(
            self.a_exemplar,
            F.interpolate(self.b_warp, (self.w, self.w), mode='bilinear'))

    def test(self):
        """Inference-only pass (no gradients, no reg warp)."""
        with torch.no_grad():
            _, _, _, self.b_warp = self.netC(self.a,
                                             self.b_exemplar)  # 3*HW*HW
            self.b_gen = self.netT(self.b_warp)

    def backward_G(self):
        """Compute all generator losses, backprop, and step optG."""
        self.optG.zero_grad()
        # 1. Perceptual loss (abandoned for human pose transfer; folded
        #    into the feat criterion)
        self.loss_perc = 0
        # 2. Domain alignment loss between the two feature embeddings
        self.loss_domain = self.opt.lambda_domain * self.criterionDomain(
            self.sa, self.sb)
        # 3. Losses for pseudo exemplar pairs
        self.loss_feat = self.opt.lambda_feat * self.criterionVGG(
            self.b_gen, self.b_gt, mode='perceptual')
        # 4. Contextual loss against the exemplar
        self.loss_context = self.opt.lambda_context * self.criterionVGG(
            self.b_gen,
            self.b_exemplar,
            mode='contextual',
            layers=[2, 3, 4, 5])
        # 5. Regularization loss: the back-warp should reconstruct the exemplar
        b_exemplar_small = F.interpolate(self.b_exemplar,
                                         self.b_reg.size()[2:],
                                         mode='bilinear')
        self.loss_reg = self.opt.lambda_reg * self.criterionReg(
            self.b_reg, b_exemplar_small)
        # 6. GAN loss
        # Bug fix: discriminate() returns (pred_fake, pred_real); the
        # original unpacked them in the opposite order, so the generator's
        # adversarial loss was computed on the *real* predictions.
        pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)
        self.loss_adv = self.opt.lambda_adv * self.criterionAdv(
            pred_fake, True, for_discriminator=False)

        g_loss = self.loss_perc + self.loss_domain + self.loss_feat \
            + self.loss_context + self.loss_reg + self.loss_adv

        g_loss.backward()
        self.optG.step()

    def discriminate(self, real, fake):
        """Run D on a fake+real batch; returns (pred_fake, pred_real)."""
        fake_and_real = torch.cat([fake, real], dim=0)
        discriminator_out = self.netD(fake_and_real)
        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        # the prediction contains the intermediate outputs of multiscale GAN,
        # so it's usually a list
        if isinstance(pred, list):
            fake = [p[:p.size(0) // 2] for p in pred]
            real = [p[p.size(0) // 2:] for p in pred]
        else:
            fake = pred[:pred.size(0) // 2]
            real = pred[pred.size(0) // 2:]

        return fake, real

    def backward_D(self):
        """Compute discriminator losses on no-grad generations and step optD."""
        self.optD.zero_grad()
        # Regenerate under no_grad so the generator graph is not reused
        self.test()

        pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)

        self.d_fake = self.criterionAdv(pred_fake,
                                        False,
                                        for_discriminator=True)
        self.d_real = self.criterionAdv(pred_real,
                                        True,
                                        for_discriminator=True)

        d_loss = (self.d_fake + self.d_real) / 2
        d_loss.backward()
        self.optD.step()

    def optimize_parameters(self):
        # must call self.set_input(data) first
        self.forward()
        self.backward_G()
        self.backward_D()

    ### Standalone utility functions
    def log_loss(self, epoch, iter):
        """Print the named losses for one iteration."""
        msg = 'Epoch %d iter %d\n  ' % (epoch, iter)
        for name in self.loss_names:
            val = getattr(self, 'loss_%s' % name)
            # Bug fix: the original only unwrapped torch.cuda.FloatTensor,
            # so CPU tensors (and non-float dtypes) broke the %f format.
            if isinstance(val, torch.Tensor):
                val = val.item()
            msg += '%s: %.4f, ' % (name, val)
        print(msg)

    def log_visual(self, epoch, iter):
        """Save a side-by-side visual of the batch's first example to disk."""
        fname = 'epoch%03d_iter%05d.png' % (epoch, iter)
        # Bug fix: __init__ stores the folder as self.image_dir;
        # self.save_image_dir never exists and raised AttributeError.
        save_path = os.path.join(self.image_dir, fname)
        # warped image is not the same resolution, need scaling
        self.b_warp = F.interpolate(self.b_warp, (self.w, self.w),
                                    mode='bicubic')
        pack = torch.cat([getattr(self, name) for name in self.visual_names],
                         dim=3)[0]  # only save one example
        cv2.imwrite(save_path, self.torch2numpy(pack))
        # Bug fix: the original prepended 'b_ex' to the *full* path,
        # producing an invalid location; prefix the file name instead.
        cv2.imwrite(os.path.join(self.image_dir, 'b_ex_' + fname),
                    self.torch2numpy(self.b_exemplar[0]))

    def update_learning_rate(self):
        '''
            Update learning rates for all the networks;
            called at the end of every epoch by train.py
        '''
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate updated to %.7f' % lr)
# Example 3
class Trainer(object):
    """Adversarial domain-adaptation trainer for DeepLab segmentation.

    The segmentation network is trained with supervision on the source
    domain; a small discriminator is trained to tell source from target
    per-pixel entropy maps, and the segmenter is additionally pushed to
    fool it (AdvEnt-style alignment).
    """

    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Dataloaders for the labeled source domain and the target domain
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(
            config)

        # Segmentation network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)

        # Domain discriminator over per-pixel entropy maps
        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        # Backbone parameters use the base lr, head parameters a multiple
        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * config.lr_ratio
        }]

        # Optimizers: SGD for the segmenter, Adam for the discriminator
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Criterion (weight=None: no class-balanced weighting)
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()
        # Evaluator (confusion-matrix based metrics)
        self.evaluator = Evaluator(self.nclass)
        # lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')
        # Labels for adversarial training
        self.source_label = 0
        self.target_label = 1

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming from a checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Bug fix: map_location is a torch.load argument, not a
            # load_state_dict one (which has no such parameter and would
            # raise TypeError on the CPU path).
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        """Run one epoch of joint segmentation + adversarial training."""
        train_loss, seg_loss_sum, bn_loss_sum, entropy_loss_sum, adv_loss_sum, d_loss_sum, ins_loss_sum = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        self.model.train()
        # Bug fix: the original read the global `config`; use the instance's.
        if self.config.freeze_bn:
            # NOTE(review): .module assumes DataParallel wrapping (args.cuda);
            # confirm freeze_bn is never combined with CPU training.
            self.model.module.freeze_bn()
        tbar = tqdm(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            self.summary.writer.add_scalar(
                'Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from the target domain, restarting the iterator
            # when exhausted (the two loaders may differ in length).
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)

            B_image, B_target, B_image_pair = target_sample[
                'image'], target_sample['label'], target_sample['image_pair']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target, B_image_pair = B_image.cuda(
                ), B_target.cuda(), B_image_pair.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)
            self.scheduler(self.D_optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)

            self.optimizer.zero_grad()
            self.D_optimizer.zero_grad()

            # --- Train the segmentation network (freeze D) ---
            for param in self.D.parameters():
                param.requires_grad = False

            # Supervised loss on the source domain
            seg_loss = self.criterion(A_output, A_target)
            main_loss = seg_loss

            # Adversarial loss: push target entropy maps to look "source".
            # dim=1 matches the legacy implicit softmax dim for 4-D input
            # and silences the deprecation warning.
            D_out = self.D(prob_2_entropy(F.softmax(B_output, dim=1)))
            adv_loss = bce_loss(D_out, self.source_label)

            main_loss += self.config.lambda_adv * adv_loss
            main_loss.backward()

            # --- Train the discriminator on detached predictions ---
            for param in self.D.parameters():
                param.requires_grad = True
            A_output_detach = A_output.detach()
            B_output_detach = B_output.detach()
            # source half of the discriminator loss
            D_source = self.D(prob_2_entropy(F.softmax(A_output_detach, dim=1)))
            source_loss = bce_loss(D_source, self.source_label)
            source_loss = source_loss / 2
            # target half of the discriminator loss
            D_target = self.D(prob_2_entropy(F.softmax(B_output_detach, dim=1)))
            target_loss = bce_loss(D_target, self.target_label)
            target_loss = target_loss / 2
            d_loss = source_loss + target_loss
            d_loss.backward()

            self.optimizer.step()
            self.D_optimizer.step()

            seg_loss_sum += seg_loss.item()
            adv_loss_sum += self.config.lambda_adv * adv_loss.item()
            d_loss_sum += d_loss.item()

            train_loss += seg_loss.item()
            self.summary.writer.add_scalar('Train/SegLoss', seg_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/AdvLoss', adv_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/DiscriminatorLoss',
                                           d_loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        # Log qualitative results from the last iteration of the epoch
        print("Add Train images at epoch" + str(epoch))
        self.summary.visualize_image('Train-Source', self.config.dataset,
                                     A_image, A_target, A_output, epoch, 5)
        self.summary.visualize_image('Train-Target', self.config.target,
                                     B_image, B_target, B_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        """Evaluate on source and target val sets; save weights and metrics."""

        def get_metrics(tbar, if_source=False):
            # Accumulate confusion statistics over one validation loader.
            self.evaluator.reset()
            test_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']

                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    # NOTE(review): unpack order (low_feat, feat) differs
                    # from training() (feat, low_feat); harmless here since
                    # only `output` is used, but confirm upstream.
                    output, low_feat, feat = self.model(image)

                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()

                target_ = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target_, pred)
            if if_source:
                print("Add Validation-Source images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Source', self.config.dataset,
                                             image, target, output, epoch, 5)
            else:
                print("Add Validation-Target images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Target', self.config.target,
                                             image, target, output, epoch, 5)
            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            if if_source:
                self.summary.writer.add_scalar('Val/SourceAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/SourceIoU', IoU, epoch)
            else:
                self.summary.writer.add_scalar('Val/TargetAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/TargetIoU', IoU, epoch)

            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)
        t_acc, t_iou, t_miou = get_metrics(tbar_target, False)

        new_pred_source = s_iou
        new_pred_target = t_iou

        if new_pred_source > self.best_pred_source or new_pred_target > self.best_pred_target:
            # NOTE(review): is_best is computed but the checkpoint below is
            # saved unconditionally every epoch — confirm intended.
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
            self.best_pred_target = max(new_pred_target, self.best_pred_target)
        print('Saving state, epoch:', epoch)
        # NOTE(review): string concatenation assumes save_folder ends with
        # '/'; os.path.join would be safer but changes the produced path.
        torch.save(
            self.model.module.state_dict(),
            self.args.save_folder + 'models/' + 'epoch' + str(epoch) + '.pth')
        loss_file = {
            's_Acc': s_acc,
            's_IoU': s_iou,
            's_mIoU': s_miou,
            't_Acc': t_acc,
            't_IoU': t_iou,
            't_mIoU': t_miou
        }
        with open(
                os.path.join(self.args.save_folder, 'eval',
                             'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
# Example 4
def main():
    """Entry point: build G2C dataloaders, networks and optimizers, then train."""
    # dataset preparation
    src_set, tgt_set, eval_set = create_dataset(mode='G2C')
    src_loader = Data.DataLoader(src_set,
                                 batch_size=parser.batch_size,
                                 shuffle=True,
                                 num_workers=parser.num_workers,
                                 pin_memory=True)
    tgt_loader = Data.DataLoader(tgt_set,
                                 batch_size=parser.batch_size,
                                 shuffle=True,
                                 num_workers=parser.num_workers,
                                 pin_memory=True)
    eval_loader = Data.DataLoader(eval_set,
                                  batch_size=parser.batch_size,
                                  shuffle=False,
                                  num_workers=parser.num_workers,
                                  pin_memory=True)
    src_iter = enumerate(src_loader)
    tgt_iter = enumerate(tgt_loader)

    ckpt_path = parser.ckpt_dir

    # one generator (DeepLab) and two discriminators
    seg_model = create_model(num_classes=parser.num_classes, name='DeepLab')
    disc1 = Discriminator(num_classes=parser.num_classes)
    disc2 = Discriminator(num_classes=parser.num_classes)

    opt_g = create_optimizer(seg_model.get_optim_params(parser),
                             lr=parser.learning_rate,
                             momentum=parser.momentum,
                             weight_decay=parser.weight_decay,
                             name="SGD")
    opt_d1 = create_optimizer(disc1.parameters(),
                              lr=LEARNING_RATE_D,
                              name="Adam",
                              betas=BETAS)
    opt_d2 = create_optimizer(disc2.parameters(),
                              lr=LEARNING_RATE_D,
                              name="Adam",
                              betas=BETAS)

    # start from clean gradients
    for opt in (opt_g, opt_d1, opt_d2):
        opt.zero_grad()

    first_iter = 1
    prev_best_miou = 0

    if parser.restore:
        print("loading checkpoint...")
        state = torch.load(ckpt_path)
        first_iter = state['iter']
        seg_model.load_state_dict(state['model'])
        opt_g.load_state_dict(state['optimizer']['G'])
        opt_d1.load_state_dict(state['optimizer']['D1'])
        opt_d2.load_state_dict(state['optimizer']['D2'])
        prev_best_miou = state['best_mIoU']

    print("start training...")
    print("pytorch version: " + TORCH_VERSION + ", cuda version: " +
          TORCH_CUDA_VERSION + ", cudnn version: " + CUDNN_VERSION)
    print("available graphical device: " + DEVICE_NAME)
    os.system("nvidia-smi")

    discriminator = {'D1': disc1, 'D2': disc2}
    optimizer = {'G': opt_g, 'D1': opt_d1, 'D2': opt_d2}

    best_mIoU, best_iter = train(seg_model, discriminator, optimizer,
                                 src_iter, tgt_iter, eval_loader,
                                 first_iter, prev_best_miou)

    print("finished training, the best mIoU is: " + str(best_mIoU) +
          " in iteration " + str(best_iter))
# Example 5
# Move the networks to the GPU when requested (is_cuda, feature_extractor
# and dis are defined earlier in the script, outside this view).
if is_cuda:
    feature_extractor.cuda()
    dis.cuda()

# input pipeline: batched (anchor, real, wrong) provider
data_iter = DataProvider(batch_size, is_cuda=is_cuda)

# TensorBoard summary writer; disabled when no log path is configured
if log_path:
    writer = SummaryWriter(log_path, 'comment test')
else:
    writer = None

# Separate Adam optimizers (default hyperparameters) for the
# discriminator and the feature extractor.
opt_d = Adam(dis.parameters())
opt_fea = Adam(feature_extractor.parameters())


def train_dis():
    """One discriminator training step on an (anchor, real, wrong) triple.

    NOTE(review): this snippet appears truncated here — the function body
    continues past this excerpt (no loss/backward is visible).
    """
    # Hard 0/1 labels; the commented lines below were a label-noise variant.
    # real_label = Variable(torch.normal(torch.ones(batch_size), torch.zeros(batch_size) + 0.02)).cuda()
    # fake_label = Variable(torch.normal(torch.zeros(batch_size), torch.zeros(batch_size) + 0.02)).cuda()
    real_label = Variable(torch.ones(batch_size).cuda())
    fake_label = Variable(torch.zeros(batch_size).cuda())

    # Pull one batch and wrap in Variables (pre-0.4 PyTorch style)
    anchor, real_img, wrong_img = data_iter.next()
    anchor, real_img, wrong_img = Variable(anchor), Variable(
        real_img), Variable(wrong_img)
    fea_anc = feature_extractor(anchor)
    fea_real = feature_extractor(real_img)
# Example 6
class Model(pl.LightningModule):
    """CycleGAN as a PyTorch-Lightning module.

    Holds two generators (``G_AB``: A -> B, ``G_BA``: B -> A) and two
    discriminators (``D_A``, ``D_B``) and trains them with three optimizers
    (index 0: both generators, 1: D_A, 2: D_B) using the standard CycleGAN
    objective: LSGAN-style adversarial MSE loss + cycle-consistency L1 loss
    + identity L1 loss.
    """

    def __init__(self, hparams, device, G_AB, G_BA):
        """
        :param hparams: hyper-parameter namespace providing input_shape,
            learning_rate, b1, b2, n_epochs, start_epoch, epoch_decay,
            batch_size, lambda_cycle_loss and lambda_identity_loss.
        :param device: torch.device the module's buffers should live on.
        :param G_AB: generator mapping domain A to domain B.
        :param G_BA: generator mapping domain B to domain A.
        """
        super(Model, self).__init__()

        self.hparams = hparams
        self.device = device

        self.input_shape = hparams.input_shape
        self.learning_rate = hparams.learning_rate
        self.B1 = hparams.b1
        # BUG FIX: the original assigned ``hparams.b1`` here as well,
        # silently reusing Adam's beta1 as beta2 in every optimizer.
        self.B2 = hparams.b2
        self.n_epochs = hparams.n_epochs
        self.start_epoch = hparams.start_epoch
        self.epoch_decay = hparams.epoch_decay
        self.batch_size = hparams.batch_size
        self.lambda_cycle_loss = hparams.lambda_cycle_loss
        self.lambda_identity_loss = hparams.lambda_identity_loss

        self.G_AB = G_AB
        self.G_BA = G_BA
        self.D_A = Discriminator(self.input_shape)
        self.D_B = Discriminator(self.input_shape)

        # Adversarial ground truths, shaped like the discriminator output.
        self.valid = torch.ones(
            (self.batch_size, *self.D_A.output_shape)).to(device)
        self.fake = torch.zeros(
            (self.batch_size, *self.D_A.output_shape)).to(device)

        # Losses: MSE for the adversarial terms (LSGAN), L1 for the
        # cycle-consistency and identity terms.
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        # Generator outputs cached by the generator step
        # (optimizer_index == 0) and reused by the discriminator steps.
        self.fake_A = None
        self.fake_B = None
        self.recov_A = None
        self.recov_B = None

    def forward(self, real_A, real_B):
        """Translate both directions at once: returns (G_AB(real_A), G_BA(real_B))."""
        return self.G_AB(real_A), self.G_BA(real_B)

    def training_step(self, batch, batch_index, optimizer_index=0):
        """One optimization step; which sub-network trains is selected by
        ``optimizer_index`` (0: generators, 1: D_A, 2: D_B).

        Returns an OrderedDict with the loss and logging dictionaries, as
        expected by Lightning.
        """
        loss = None
        loss_type = None

        # Set model input
        real_A = batch["A"].to(self.device)
        real_B = batch["B"].to(self.device)

        # -------------------------------
        #  Train Generators (G_AB, G_BA)
        # -------------------------------
        if optimizer_index == 0:
            # Identity loss: a generator fed an image already in its target
            # domain should return it unchanged.
            loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
            loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss: fool both discriminators into labelling fakes valid.
            self.fake_B, self.fake_A = self.forward(real_A, real_B)
            loss_GAN_AB = self.criterion_GAN(self.D_B(self.fake_B), self.valid)
            loss_GAN_BA = self.criterion_GAN(self.D_A(self.fake_A), self.valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Round-trip: fake_A -> recov_B? No — forward(fake_A, fake_B)
            # gives (G_AB(fake_A), G_BA(fake_B)) = (recov_B, recov_A).
            self.recov_B, self.recov_A = self.forward(self.fake_A, self.fake_B)

            # Cycle loss: the round trip must reconstruct the input.
            loss_cycle_A = self.criterion_cycle(self.recov_A, real_A)
            loss_cycle_B = self.criterion_cycle(self.recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss = loss_GAN + self.lambda_cycle_loss * loss_cycle + self.lambda_identity_loss * loss_identity

            loss_type = 'G'

        # -----------------------
        #  Train Discriminator A
        # -----------------------
        elif optimizer_index == 1:
            # Real loss
            loss_real = self.criterion_GAN(self.D_A(real_A), self.valid)
            # Fake loss (detach so no gradient flows into the generators)
            loss_fake = self.criterion_GAN(self.D_A(self.fake_A.detach()),
                                           self.fake)
            # Total loss
            loss = (loss_real + loss_fake) / 2
            loss_type = "D_A"

        # -----------------------
        #  Train Discriminator B
        # -----------------------
        elif optimizer_index == 2:
            # Real loss
            loss_real = self.criterion_GAN(self.D_B(real_B), self.valid)
            # Fake loss (detach so no gradient flows into the generators)
            loss_fake = self.criterion_GAN(self.D_B(self.fake_B.detach()),
                                           self.fake)
            # Total loss
            loss = (loss_real + loss_fake) / 2
            loss_type = "D_B"

        tqdm_dict = {f"{loss_type}_loss": loss}

        return OrderedDict({
            'loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

    def validation_step(self, batch, batch_nb):
        """Validation reuses the generator branch of ``training_step``;
        the first batch of each run is also rendered to TensorBoard."""
        if batch_nb == 0:
            self.sample_network_images(batch)

        loss_data = self.training_step(batch, batch_nb)

        return {
            'val_loss': loss_data['loss'],
            'progress_bar': loss_data['progress_bar'],
            'log': loss_data['log']
        }

    def validation_end(self, outputs):
        """Average the per-batch validation losses."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        """Three Adam optimizers, matching ``optimizer_index`` in
        ``training_step``: [generators, D_A, D_B], each with its own
        LambdaLR decay schedule."""
        optimizer_G = torch.optim.Adam(itertools.chain(self.G_AB.parameters(),
                                                       self.G_BA.parameters()),
                                       lr=self.learning_rate,
                                       betas=(self.B1, self.B2))
        optimizer_D_A = torch.optim.Adam(self.D_A.parameters(),
                                         lr=self.learning_rate,
                                         betas=(self.B1, self.B2))
        optimizer_D_B = torch.optim.Adam(self.D_B.parameters(),
                                         lr=self.learning_rate,
                                         betas=(self.B1, self.B2))

        # Learning rate update schedulers
        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G,
            lr_lambda=LambdaLRSteper(self.n_epochs, self.start_epoch,
                                     self.epoch_decay).step)
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A,
            lr_lambda=LambdaLRSteper(self.n_epochs, self.start_epoch,
                                     self.epoch_decay).step)
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B,
            lr_lambda=LambdaLRSteper(self.n_epochs, self.start_epoch,
                                     self.epoch_decay).step)

        return [optimizer_G, optimizer_D_A, optimizer_D_B
                ], [lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B]

    @pl.data_loader
    def train_dataloader(self):
        return self.create_data_loader(MODE_TRAIN)

    @pl.data_loader
    def val_dataloader(self):
        return self.create_data_loader(MODE_VAL)

    @pl.data_loader
    def test_dataloader(self):
        return self.create_data_loader(MODE_TEST)

    def create_data_loader(self, mode):
        """Build a shuffled DataLoader over ImageDataset for the given mode."""
        return DataLoader(
            ImageDataset(mode),
            batch_size=self.batch_size,
            shuffle=True,
            # num_workers=multiprocessing.cpu_count(),
        )

    def sample_network_images(self, batch):
        """Saves a generated sample from the test set"""
        real_A = batch["A"].to(self.device)
        real_B = batch["B"].to(self.device)
        fake_B, fake_A = self.forward(real_A, real_B)

        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)

        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        self.logger.experiment.add_image(f'sample_images_{self.current_epoch}',
                                         image_grid, 0)
Ejemplo n.º 7
0
        if transition_coef <= 1. and transition:

            generator = Generator(config['sr'][level_index],
                                  config,
                                  transition=transition,
                                  transition_coef=transition_coef).to(device)
            critic = Discriminator(config['sr'][level_index],
                                   config,
                                   transition=transition,
                                   transition_coef=transition_coef).to(device)

            generator_optimizer = torch.optim.Adam(
                generator.parameters(),
                lr=lr_gen,
                betas=config['generator_betas'])
            critic_optimizer = torch.optim.Adam(critic.parameters(),
                                                lr=lr_dis,
                                                betas=config['critic_betas'])

        elif transition_coef >= 1. and transition:

            transition = False
            generator = Generator(config['sr'][level_index],
                                  config,
                                  transition=transition,
                                  transition_coef=transition_coef).to(device)
            critic = Discriminator(config['sr'][level_index],
                                   config,
                                   transition=transition,
                                   transition_coef=transition_coef).to(device)
Ejemplo n.º 8
0
class Hidden:
    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super(Hidden, self).__init__()

        # Encoder embeds a message into an image; decoder recovers it.
        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())

        # Discriminator tries to tell cover images from encoded images.
        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()

        # Labels used for training the discriminator/adversarial loss.
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            from tensorboard_logger import TensorBoardLogger
            # Register gradient hooks on the final layer of each sub-network
            # so their output gradients show up in TensorBoard.
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))

    def _label_tensor(self, batch_size: int, label: int) -> torch.Tensor:
        """Return a (batch_size, 1) float tensor filled with `label`.

        BUG FIX: the original called torch.full with an int fill value and no
        dtype, which produces an integer tensor on modern PyTorch and makes
        BCEWithLogitsLoss fail; float dtype is requested explicitly.
        """
        return torch.full((batch_size, 1), label, dtype=torch.float,
                          device=self.device)

    def train_on_batch(self, batch: list):
        """
        Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        images, messages = batch
        batch_size = images.shape[0]
        with torch.enable_grad():
            # ---------------- Train the discriminator -----------------------------
            self.optimizer_discrim.zero_grad()
            # train on cover
            d_target_label_cover = self._label_tensor(batch_size,
                                                      self.cover_label)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake (detach so discriminator gradients do not reach
            # the encoder-decoder)
            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = self._label_tensor(batch_size,
                                                        self.encoded_label)
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encoded.backward()
            self.optimizer_discrim.step()

            # --------------Train the generator (encoder-decoder) ---------------------
            self.optimizer_enc_dec.zero_grad()
            # target label for encoded images should be 'cover', because we want to fool the discriminator
            g_target_label_encoded = self._label_tensor(batch_size,
                                                        self.cover_label)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            # BUG FIX: identity comparison (`is None`) instead of `== None`.
            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                # Perceptual distortion loss in VGG feature space.
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)

            # Weighted sum of adversarial, encoder (distortion) and decoder
            # (message recovery) losses.
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.optimizer_enc_dec.step()

        # Bitwise error: round decoded messages to {0, 1} and compare.
        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """
        Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """

        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            self.tb_logger.add_tensor('weights/encoder_out',
                                      encoder_final.weight)
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            self.tb_logger.add_tensor('weights/decoder_out',
                                      decoder_final.weight)
            discrim_final = self.discriminator._modules['linear']
            self.tb_logger.add_tensor('weights/discrim_out',
                                      discrim_final.weight)

        images, messages = batch
        batch_size = images.shape[0]

        with torch.no_grad():
            # BUG FIX: the original ran the discriminator on `images` twice
            # in a row; one forward pass suffices.
            d_target_label_cover = self._label_tensor(batch_size,
                                                      self.cover_label)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = self._label_tensor(batch_size,
                                                        self.encoded_label)
            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            g_target_label_encoded = self._label_tensor(batch_size,
                                                        self.cover_label)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_string(self):
        """Human-readable description of the encoder-decoder and discriminator."""
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))

    # Kept for backward compatibility: the original public API shipped this
    # misspelled name, so existing callers may depend on it.
    def to_stirng(self):
        return self.to_string()
def main(data_dir):
    """Train a CycleGAN (generators g_x/g_y, discriminators d_x/d_y) mapping
    between face images and UV position maps on the 300W-LP dataset.

    :param data_dir: root directory of the 300W-LP dataset.
    """
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing: normalization applied to test images.
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    g_x = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    g_y = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    d_x = Discriminator()
    d_y = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        try:
            g_x.load_state_dict(state['g_x'])
            g_y.load_state_dict(state['g_y'])
            d_x.load_state_dict(state['d_x'])
            d_y.load_state_dict(state['d_y'])
        except Exception:
            # Fall back to a plain PRNet checkpoint (generator only).
            g_x.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    # BUG FIX: the original wrapped an undefined name `model` in
    # DataParallel (NameError on multi-GPU hosts); wrap the networks that
    # actually exist. NOTE(review): DataParallel prefixes state_dict keys
    # with 'module.' -- confirm checkpoint compatibility on multi-GPU runs.
    if torch.cuda.device_count() > 1:
        g_x = torch.nn.DataParallel(g_x)
        g_y = torch.nn.DataParallel(g_y)
        d_x = torch.nn.DataParallel(d_x)
        d_y = torch.nn.DataParallel(d_y)
    g_x.to(FLAGS["device"])
    g_y.to(FLAGS["device"])
    d_x.to(FLAGS["device"])
    d_y.to(FLAGS["device"])

    # One optimizer for both generators, one for both discriminators.
    optimizer_g = torch.optim.Adam(itertools.chain(g_x.parameters(),
                                                   g_y.parameters()),
                                   lr=FLAGS["lr"],
                                   betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(itertools.chain(d_x.parameters(),
                                                   d_y.parameters()),
                                   lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)

    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])
    l1_loss = nn.L1Loss().to(FLAGS["device"])
    # Cycle-consistency weights for each direction.
    lambda_X = 10
    lambda_Y = 10
    # Loss function for adversarial
    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_cycle_x = []
        loss_list_cycle_y = []
        loss_list_d_x = []
        loss_list_d_y = []
        # BUG FIX: labels must live on the same device as the predictions,
        # otherwise BCE raises a device-mismatch error on CUDA.
        real_label = torch.ones(FLAGS['batch_size']).to(FLAGS['device'])
        fake_label = torch.zeros(FLAGS['batch_size']).to(FLAGS['device'])
        for i, sample in enumerate(bar):
            real_y, real_x = sample['uv_map'].to(
                FLAGS['device']), sample['origin'].to(FLAGS['device'])
            # x -> y' -> x^
            optimizer_g.zero_grad()
            fake_y = g_x(real_x)
            prediction = d_x(fake_y)
            loss_g_x = bce_loss(prediction, real_label)
            x_hat = g_y(fake_y)
            loss_cycle_x = l1_loss(x_hat, real_x) * lambda_X
            loss_x = loss_g_x + loss_cycle_x
            loss_x.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_x.append(loss_x.item())
            # y -> x' -> y^
            optimizer_g.zero_grad()
            fake_x = g_y(real_y)
            prediction = d_y(fake_x)
            loss_g_y = bce_loss(prediction, real_label)
            y_hat = g_x(fake_x)
            loss_cycle_y = l1_loss(y_hat, real_y) * lambda_Y
            loss_y = loss_g_y + loss_cycle_y
            loss_y.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_y.append(loss_y.item())
            # d_x: real UV maps vs generated ones
            optimizer_d.zero_grad()
            pred_real = d_x(real_y)
            loss_d_x_real = bce_loss(pred_real, real_label)
            pred_fake = d_x(fake_y)
            loss_d_x_fake = bce_loss(pred_fake, fake_label)
            loss_d_x = (loss_d_x_real + loss_d_x_fake) * 0.5
            loss_d_x.backward()
            loss_list_d_x.append(loss_d_x.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                # WGAN weight clipping.
                for p in d_x.parameters():
                    p.data.clamp_(-1, 1)
            # d_y: real images vs generated ones
            optimizer_d.zero_grad()
            pred_real = d_y(real_x)
            loss_d_y_real = bce_loss(pred_real, real_label)
            pred_fake = d_y(fake_x)
            loss_d_y_fake = bce_loss(pred_fake, fake_label)
            loss_d_y = (loss_d_y_real + loss_d_y_fake) * 0.5
            loss_d_y.backward()
            loss_list_d_y.append(loss_d_y.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_y.parameters():
                    p.data.clamp_(-1, 1)

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                # BUG FIX: the original printed `loss_list_g_x`/`loss_list_g_y`,
                # names that were never defined (NameError); report the
                # populated cycle-loss lists instead.
                print(
                    " {} [Loss_G_X] {} [Loss_G_Y] {} [Loss_D_X] {} [Loss_D_Y] {}"
                    .format(ep, loss_list_cycle_x[-1], loss_list_cycle_y[-1],
                            loss_list_d_x[-1], loss_list_d_y[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = g_x(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model
            print("Save model")
            state = {
                'g_x': g_x.state_dict(),
                'g_y': g_y.state_dict(),
                'd_x': d_x.state_dict(),
                'd_y': d_y.state_dict(),
                'start_epoch': ep,
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

        # BUG FIX: the original stepped the scheduler only inside the
        # save-interval branch, so the generator LR decayed on save epochs
        # only; decay it once per epoch.
        scheduler.step()
Ejemplo n.º 10
0
def train(train_sources, eval_source):
    path = sys.argv[1]
    dr = DataReader(path, train_sources)
    dr.read()
    print(len(dr.train.x))

    batch_size = 8
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    dataset_s_train = MultiDomainDataset(dr.train.x, dr.train.y, dr.train.vendor, device, DomainAugmentation())
    dataset_s_dev = MultiDomainDataset(dr.dev.x, dr.dev.y, dr.dev.vendor, device)
    dataset_s_test = MultiDomainDataset(dr.test.x, dr.test.y, dr.test.vendor, device)
    loader_s_train = DataLoader(dataset_s_train, batch_size, shuffle=True)

    dr_eval = DataReader(path, [eval_source])
    dr_eval.read()

    dataset_eval_dev = MultiDomainDataset(dr_eval.dev.x, dr_eval.dev.y, dr_eval.dev.vendor, device)
    dataset_eval_test = MultiDomainDataset(dr_eval.test.x, dr_eval.test.y, dr_eval.test.vendor, device)

    dataset_da_train = MultiDomainDataset(dr.train.x+dr_eval.train.x, dr.train.y+dr_eval.train.y, dr.train.vendor+dr_eval.train.vendor, device, DomainAugmentation())
    loader_da_train = DataLoader(dataset_da_train, batch_size, shuffle=True)

    segmentator = UNet()
    discriminator = Discriminator(n_domains=len(train_sources))
    discriminator.to(device)
    segmentator.to(device)

    sigmoid = nn.Sigmoid()
    selector = Selector()

    s_criterion = nn.BCELoss()
    d_criterion = nn.CrossEntropyLoss()
    s_optimizer = optim.AdamW(segmentator.parameters(), lr=0.0001, weight_decay=0.01)
    d_optimizer = optim.AdamW(discriminator.parameters(), lr=0.001, weight_decay=0.01)
    a_optimizer = optim.AdamW(segmentator.encoder.parameters(), lr=0.001, weight_decay=0.01)
    lmbd = 1/150
    s_train_losses = []
    s_dev_losses = []
    d_train_losses = []
    eval_domain_losses = []
    train_dices = []
    dev_dices = []
    eval_dices = []
    epochs = 3
    da_loader_iter = iter(loader_da_train)
    for epoch in tqdm(range(epochs)):
        s_train_loss = 0.0
        d_train_loss = 0.0
        for index, sample in enumerate(loader_s_train):
            img = sample['image']
            target_mask = sample['target']

            da_sample = next(da_loader_iter, None)
            if epoch == 100:
                s_optimizer.defaults['lr'] = 0.001
                d_optimizer.defaults['lr'] = 0.0001
            if da_sample is None:
                da_loader_iter = iter(loader_da_train)
                da_sample = next(da_loader_iter, None)
            if epoch < 50 or epoch >= 100:
                # Training step of segmentator
                predicted_activations, inner_repr = segmentator(img)
                predicted_mask = sigmoid(predicted_activations)
                s_loss = s_criterion(predicted_mask, target_mask)
                s_optimizer.zero_grad()
                s_loss.backward()
                s_optimizer.step()
                s_train_loss += s_loss.cpu().detach().numpy()

            if epoch >= 50:
                # Training step of discriminator
                predicted_activations, inner_repr = segmentator(da_sample['image'])
                predicted_activations = predicted_activations.clone().detach()
                inner_repr = inner_repr.clone().detach()
                predicted_vendor = discriminator(predicted_activations, inner_repr)
                d_loss = d_criterion(predicted_vendor, da_sample['vendor'])
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
                d_train_loss += d_loss.cpu().detach().numpy()

            if epoch >= 100:
                # adversarial training step
                predicted_mask, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_mask, inner_repr)
                a_loss = -1 * lmbd * d_criterion(predicted_vendor, da_sample['vendor'])
                a_optimizer.zero_grad()
                a_loss.backward()
                a_optimizer.step()
                lmbd += 1/150
        inference_model = nn.Sequential(segmentator, selector, sigmoid)
        inference_model.to(device)
        inference_model.eval()
        d_train_losses.append(d_train_loss / len(loader_s_train))
        s_train_losses.append(s_train_loss / len(loader_s_train))
        s_dev_losses.append(calculate_loss(dataset_s_dev, inference_model, s_criterion, batch_size))
        eval_domain_losses.append(calculate_loss(dataset_eval_dev, inference_model, s_criterion, batch_size))

        train_dices.append(calculate_dice(inference_model, dataset_s_train))
        dev_dices.append(calculate_dice(inference_model, dataset_s_dev))
        eval_dices.append(calculate_dice(inference_model, dataset_eval_dev))

        segmentator.train()

    date_time = datetime.now().strftime("%m%d%Y_%H%M%S")
    model_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "model", "weights", "segmentator"+str(date_time)+".pth")
    torch.save(segmentator.state_dict(), model_path)

    util.plot_data([(s_train_losses, 'train_losses'), (s_dev_losses, 'dev_losses'), (d_train_losses, 'discriminator_losses'),
               (eval_domain_losses, 'eval_domain_losses')],
              'losses.png')
    util.plot_dice([(train_dices, 'train_dice'), (dev_dices, 'dev_dice'), (eval_dices, 'eval_dice')],
              'dices.png')

    inference_model = nn.Sequential(segmentator, selector, sigmoid)
    inference_model.to(device)
    inference_model.eval()

    print('Dice on annotated: ', calculate_dice(inference_model, dataset_s_test))
    print('Dice on unannotated: ', calculate_dice(inference_model, dataset_eval_test))
Ejemplo n.º 11
0
def main():
    """Main function that trains and/or evaluates a model.

    Builds the ATIS generator (and, for GAN training, a CNN discriminator),
    optionally restores checkpoints, then runs: MLE pre-training of the
    generator, discriminator pre-training, adversarial (or mixed-MLE)
    training, and finally evaluation on the validation split — each phase
    gated by the corresponding flag in ``params``.
    """
    params = interpret_args()

    if params.gan:
        # The discriminator consumes fixed-length sequences, so the
        # generation length must agree with both SQL-length limits.
        assert params.max_gen_len == params.train_maximum_sql_length \
               == params.eval_maximum_sql_length
        data = atis_data.ATISDataset(params)

        generator = SchemaInteractionATISModel(params, data.input_vocabulary,
                                               data.output_vocabulary,
                                               data.output_vocabulary_schema,
                                               None)

        generator = generator.cuda()

        generator.build_optim()

        # Optionally resume the generator (and its optimizer state; BERT
        # optimizer too when fine-tuning) from a pre-training checkpoint.
        if params.gen_from_ckp:
            gen_ckp_path = os.path.join(params.logdir, params.gen_pretrain_ckp)
            if params.fine_tune_bert:
                gen_epoch, generator, generator.trainer, \
                    generator.bert_trainer = \
                    load_ckp(
                        gen_ckp_path,
                        generator,
                        generator.trainer,
                        generator.bert_trainer
                    )
            else:
                gen_epoch, generator, generator.trainer, _ = \
                    load_ckp(
                        gen_ckp_path,
                        generator,
                        generator.trainer
                    )
        else:
            gen_epoch = 0

        print('====================Model Parameters====================')
        print('=======================Generator========================')
        for name, param in generator.named_parameters():
            print(name, param.requires_grad, param.is_cuda, param.size())
            assert param.is_cuda

        print('==================Optimizer Parameters==================')
        print('=======================Generator========================')
        for param_group in generator.trainer.param_groups:
            print(param_group.keys())
            for param in param_group['params']:
                print(param.size())

        if params.fine_tune_bert:
            print('=========================BERT===========================')
            for param_group in generator.bert_trainer.param_groups:
                print(param_group.keys())
                for param in param_group['params']:
                    print(param.size())

        sys.stdout.flush()

        # Pre-train generator with MLE
        if params.train:
            print('=============== Pre-training generator! ================')
            train(generator, data, params, gen_epoch)
            print('=========== Pre-training generator complete! ===========')

        # CNN discriminator config: one filter size (and filter count)
        # every 4 positions up to the maximum generation length.
        dis_filter_sizes = list(range(1, params.max_gen_len, 4))
        dis_num_filters = [(100 + i * 10)
                           for i in range(1, params.max_gen_len, 4)]

        discriminator = Discriminator(params, data.dis_src_vocab,
                                      data.dis_tgt_vocab, params.max_gen_len,
                                      params.num_dis_classes, dis_filter_sizes,
                                      dis_num_filters, params.max_pos_emb,
                                      params.num_tok_type, params.dis_dropout)

        discriminator = discriminator.cuda()

        dis_criterion = nn.NLLLoss(reduction='mean')
        dis_criterion = dis_criterion.cuda()
        dis_optimizer = optim.Adam(discriminator.parameters())

        if params.dis_from_ckp:
            dis_ckp_path = os.path.join(params.logdir, params.dis_pretrain_ckp)
            dis_epoch, discriminator, dis_optimizer, _ = load_ckp(
                dis_ckp_path, discriminator, dis_optimizer)
        else:
            dis_epoch = 0

        print('====================Model Parameters====================')
        print('=====================Discriminator======================')
        for name, param in discriminator.named_parameters():
            print(name, param.requires_grad, param.is_cuda, param.size())
            assert param.is_cuda

        print('==================Optimizer Parameters==================')
        print('=====================Discriminator======================')
        for param_group in dis_optimizer.param_groups:
            print(param_group.keys())
            for param in param_group['params']:
                print(param.size())

        sys.stdout.flush()

        # Pre-train discriminator
        if params.pretrain_discriminator:
            print('============= Pre-training discriminator! ==============')
            pretrain_discriminator(params,
                                   generator,
                                   discriminator,
                                   dis_criterion,
                                   dis_optimizer,
                                   data,
                                   start_epoch=dis_epoch)
            print('========= Pre-training discriminator complete! =========')

        # Adversarial Training
        if params.adversarial_training:
            print('================ Adversarial training! =================')
            # Fresh optimizers and criterion for the adversarial phase.
            generator.build_optim()
            dis_criterion = nn.NLLLoss(reduction='mean')
            dis_optimizer = optim.Adam(discriminator.parameters())
            dis_criterion = dis_criterion.cuda()

            # Fixed: the original compared ``params.mle is not "mixed_mle"``;
            # identity comparison with a string literal is implementation-
            # dependent (and a SyntaxWarning on modern CPython) — use ``!=``.
            if params.adv_from_ckp and params.mle != "mixed_mle":
                adv_ckp_path = os.path.join(params.logdir, params.adv_ckp)
                if params.fine_tune_bert:
                    epoch, batches, pos_in_batch, generator, discriminator, \
                        generator.trainer, dis_optimizer, \
                        generator.bert_trainer, _, _ = \
                        load_adv_ckp(
                            adv_ckp_path,
                            generator,
                            discriminator,
                            generator.trainer,
                            dis_optimizer,
                            generator.bert_trainer)
                else:
                    epoch, batches, pos_in_batch, generator, discriminator, \
                        generator.trainer, dis_optimizer, _, _, _ = \
                        load_adv_ckp(
                            adv_ckp_path,
                            generator,
                            discriminator,
                            generator.trainer,
                            dis_optimizer)
                adv_train(generator,
                          discriminator,
                          dis_criterion,
                          dis_optimizer,
                          data,
                          params,
                          start_epoch=epoch,
                          start_batches=batches,
                          start_pos_in_batch=pos_in_batch)

            elif params.adv_from_ckp and params.mle == "mixed_mle":
                adv_ckp_path = os.path.join(params.logdir, params.adv_ckp)
                if params.fine_tune_bert:
                    epoch, batches, pos_in_batch, generator, discriminator, \
                        generator.trainer, dis_optimizer, \
                        generator.bert_trainer, clamp, length = \
                        load_adv_ckp(
                            adv_ckp_path,
                            generator,
                            discriminator,
                            generator.trainer,
                            dis_optimizer,
                            generator.bert_trainer,
                            mle=True)
                else:
                    epoch, batches, pos_in_batch, generator, discriminator, \
                        generator.trainer, dis_optimizer, _, clamp, length = \
                        load_adv_ckp(
                            adv_ckp_path,
                            generator,
                            discriminator,
                            generator.trainer,
                            dis_optimizer,
                            mle=True)
                mixed_mle(generator,
                          discriminator,
                          dis_criterion,
                          dis_optimizer,
                          data,
                          params,
                          start_epoch=epoch,
                          start_batches=batches,
                          start_pos_in_batch=pos_in_batch,
                          start_clamp=clamp,
                          start_len=length)
            else:
                # No checkpoint: start the chosen training scheme from scratch.
                if params.mle == 'mixed_mle':
                    mixed_mle(generator, discriminator, dis_criterion,
                              dis_optimizer, data, params)
                else:
                    adv_train(generator, discriminator, dis_criterion,
                              dis_optimizer, data, params)

        if params.evaluate and 'valid' in params.evaluate_split:
            print("================== Evaluating! ===================")
            evaluate(generator, data, params, split='valid')
            print("============= Evaluation finished! ===============")
Ejemplo n.º 12
0
class Trainer:
    """DCGAN-style trainer.

    Wires up a Generator/Discriminator pair, trains them adversarially with
    BCE loss on real/fake labels, and provides plotting helpers for losses
    and generated samples.  If ``autosave`` is a directory path, figures are
    also written there (path is joined with ``"\\"`` — Windows-specific;
    NOTE(review): non-portable, confirm target platform before reuse).
    """

    def __init__(self, nc=1, nz=100, ngf=64, ndf=64, lr=0.0002, beta1=0.5, ngpu=1, autosave=None):
        """Build networks, loss, and optimizers.

        nc: image channels; nz: latent vector size; ngf/ndf: base feature-map
        widths of G and D; lr/beta1: Adam hyperparameters; ngpu: >0 enables
        CUDA if available; autosave: directory for saved figures (or None).
        """
        self.nz = nz
        self.dataloader = None
        self.img_list = None
        self.G_losses = None
        self.D_losses = None
        self.device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
        self.netG = Generator(nz, ngf, nc).to(self.device)
        self.netG.apply(utils.weights_init)
        self.netD = Discriminator(nc, ndf).to(self.device)
        self.netD.apply(utils.weights_init)
        self.criterion = nn.BCELoss()
        # Fixed latent batch reused every epoch so sample grids are comparable.
        self.fixed_noise = torch.randn(64, nz, 1, 1, device=self.device)
        self.real_label = 1.
        self.fake_label = 0.
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=lr, betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=(beta1, 0.999))
        self.result_path = autosave

    # Set the dataloader
    def load_data(self, dataloader):
        """Attach the DataLoader providing real training images."""
        self.dataloader = dataloader
        print("Dataloader is prepared!")

    # Train the networks
    def train(self, num_epochs, render=False):
        """Run the adversarial training loop for ``num_epochs`` epochs.

        Requires ``load_data`` to have been called first.  Per batch: train D
        on a real batch, then on a fake batch, step D; then train G against
        the updated D, step G.  After each epoch a sample grid from the fixed
        noise is stored (and optionally shown when ``render`` is True).
        """

        # Check whether the dataloader is None, if so quit the training
        if self.dataloader is None:
            print("Data has not been loaded yet!")
            return

        # Else start the training
        print("Start training loop...")
        # Re-initialize weights so repeated train() calls start fresh.
        self.netG.apply(utils.weights_init)
        self.netD.apply(utils.weights_init)
        self.img_list = []
        self.G_losses = []
        self.D_losses = []

        for epoch in range(num_epochs):
            # For each batch in the dataloader
            for i, data in enumerate(self.dataloader, 0):

                # Train the discriminator with real images
                self.netD.zero_grad()
                real_cpu = data[0].to(self.device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)
                output = self.netD(real_cpu).view(-1)
                errD_real = self.criterion(output, label)
                errD_real.backward()
                D_x = output.mean().item()

                # Train the discriminator with fake images
                noise = torch.randn(b_size, self.nz, 1, 1, device=self.device)
                fake = self.netG(noise)
                label.fill_(self.fake_label)
                # detach() so this backward pass does not reach into G.
                output = self.netD(fake.detach()).view(-1)
                errD_fake = self.criterion(output, label)
                errD_fake.backward()
                D_G_z1 = output.mean().item()

                # Update the parameters (gradients from both backward passes
                # have accumulated on D).
                errD = errD_real + errD_fake
                self.optimizerD.step()

                # Train the generator: G wants D to label its fakes as real.
                self.netG.zero_grad()
                label.fill_(self.real_label)
                output = self.netD(fake).view(-1)
                errG = self.criterion(output, label)
                errG.backward()
                D_G_z2 = output.mean().item()
                self.optimizerG.step()

                if i % 100 == 0:
                    print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                          % (epoch, num_epochs, i, len(self.dataloader),
                             errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

                self.G_losses.append(errG.item())
                self.D_losses.append(errD.item())

            with torch.no_grad():
                # Fixed noise is employed here, because I want to generate images of same digits for more conspicuous comparisons
                fake = self.netG(self.fixed_noise).detach().cpu()
            self.img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

            # Plot the training result of this epoch
            self.draw_current_image(epoch, show=render)

    # Draw the original image
    def draw_original_image(self):
        """Show (and optionally save) a grid of 64 real training images."""
        real_batch = next(iter(self.dataloader))
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(self.device)[:64], padding=2, normalize=True).cpu(), (1, 2, 0)))
        if self.result_path is not None:
            plt.savefig(self.result_path + "\\origin.png")
        plt.show()

    # Plot the loss curves of both the generator and the discriminator
    def plot_loss(self):
        """Plot per-iteration G and D loss curves collected during train()."""
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(self.G_losses, label="G")
        plt.plot(self.D_losses, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        if self.result_path is not None:
            plt.savefig(self.result_path + "\\loss.png")
        plt.show()

    # Draw the last 64 figures
    def draw_current_image(self, current_epoch, show=False):
        """Render the most recent fixed-noise sample grid for this epoch."""
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Fake Images_" + str(current_epoch))
        plt.imshow(np.transpose(self.img_list[-1], (1, 2, 0)))
        if self.result_path is not None:
            plt.savefig(self.result_path + "\\fake_" + str(current_epoch) + ".png")
        if show:
            plt.show()
Ejemplo n.º 13
0
class RefSRSolver(BaseSolver):
    """Solver for reference-based super-resolution (SRNTT) training.

    Trains reconstruction-only for the first ``init_epoch`` epochs, then a
    full phase that adds perceptual, texture, and adversarial losses with a
    discriminator and a (frozen) VGG19 feature extractor.
    """

    def __init__(self, cfg):
        super(RefSRSolver, self).__init__(cfg)

        self.srntt = SRNTT(cfg['model']['n_resblocks'],
                           cfg['schedule']['use_weights'],
                           cfg['schedule']['concat']).cuda()
        self.discriminator = Discriminator(cfg['data']['input_size']).cuda()
        self.vgg = VGG19(cfg['model']['final_layer'],
                         cfg['model']['prev_layer'], True).cuda()
        # Only the texture-transfer / fusion / output sub-modules are trained.
        params = list(self.srntt.texture_transfer.parameters()) + \
                 list(self.srntt.texture_fusion_medium.parameters()) + \
                 list(self.srntt.texture_fusion_large.parameters()) + \
                 list(self.srntt.srntt_out.parameters())
        self.init_epoch = self.cfg['schedule']['init_epoch']
        self.num_epochs = self.cfg['schedule']['num_epochs']
        self.optimizer_init = torch.optim.Adam(params,
                                               lr=cfg['schedule']['lr'])
        # NOTE: self.optimizer and self.optimizer_d are LR *schedulers*
        # wrapping fresh Adam instances; the underlying optimizer is reached
        # via ``.optimizer``.  LR drops 10x at the halfway epoch.
        self.optimizer = torch.optim.lr_scheduler.MultiStepLR(
            torch.optim.Adam(params, lr=cfg['schedule']['lr']),
            [self.num_epochs // 2], 0.1)
        self.optimizer_d = torch.optim.lr_scheduler.MultiStepLR(
            torch.optim.Adam(self.discriminator.parameters(),
                             lr=cfg['schedule']['lr']), [self.num_epochs // 2],
            0.1)
        self.reconst_loss = nn.L1Loss()
        self.bp_loss = BackProjectionLoss()
        self.texture_loss = TextureLoss(self.cfg['schedule']['use_weights'],
                                        80)
        self.adv_loss = AdvLoss(self.cfg['schedule']['is_WGAN_GP'])
        # Ordering: [perceptual, texture, adversarial, back-projection,
        # reconstruction] — see the weighted sum in train().
        self.loss_weights = self.cfg['schedule']['loss_weights']

    def train(self):
        """Run one training epoch (init phase or full GAN phase)."""
        if self.epoch <= self.init_epoch:
            # Phase 1: reconstruction + back-projection losses only.
            with tqdm(total=len(self.train_loader),
                      miniters=1,
                      desc='Initial Training Epoch: [{}/{}]'.format(
                          self.epoch, self.max_epochs)) as t:
                for data in self.train_loader:
                    lr, hr = data['lr'].cuda(), data['hr'].cuda()
                    maps, weight = data['map'].cuda(), data['weight'].cuda()
                    self.srntt.train()
                    self.optimizer_init.zero_grad()
                    sr, srntt_out = self.srntt(lr, weight, maps)
                    loss_reconst = self.reconst_loss(sr, hr)
                    loss_bp = self.bp_loss(lr, srntt_out)
                    loss = self.loss_weights[
                        4] * loss_reconst + self.loss_weights[3] * loss_bp
                    t.set_postfix_str("Batch loss {:.4f}".format(loss.item()))
                    t.update()

                    loss.backward()
                    self.optimizer_init.step()
        elif self.epoch <= self.num_epochs:
            # Phase 2: full objective with adversarial training.
            with tqdm(total=len(self.train_loader),
                      miniters=1,
                      desc='Complete Training Epoch: [{}/{}]'.format(
                          self.epoch, self.max_epochs)) as t:
                for data in self.train_loader:
                    lr, hr = data['lr'].cuda(), data['hr'].cuda()
                    maps, weight = data['map'].cuda(), data['weight'].cuda()
                    self.srntt.train()
                    self.optimizer_init.zero_grad()
                    self.optimizer.optimizer.zero_grad()
                    self.optimizer_d.optimizer.zero_grad()
                    sr, srntt_out = self.srntt(lr, weight, maps)
                    sr_prevlayer, sr_lastlayer = self.vgg(srntt_out)
                    hr_prevlayer, hr_lastlayer = self.vgg(hr)
                    _, d_real_logits = self.discriminator(hr)
                    _, d_fake_logits = self.discriminator(srntt_out)
                    loss_reconst = self.reconst_loss(sr, hr)
                    loss_bp = self.bp_loss(lr, srntt_out)
                    loss_texture = self.texture_loss(sr_prevlayer, maps,
                                                     weight)
                    loss_d, loss_g = self.adv_loss(srntt_out, hr,
                                                   d_fake_logits,
                                                   d_real_logits,
                                                   self.discriminator)
                    loss_percep = torch.pow(sr_lastlayer - hr_lastlayer,
                                            2).mean()
                    if self.cfg['schedule']['use_lower_layers_in_per_loss']:
                        for l_sr, l_hr in zip(sr_prevlayer, hr_prevlayer):
                            loss_percep += torch.pow(l_sr - l_hr, 2).mean()
                        loss_percep = loss_percep / (len(sr_prevlayer) + 1)
                    # Fixed: the original wrapped the losses in a new
                    # torch.Tensor, which converts them to plain floats and
                    # detaches the autograd graph, so backward() could not
                    # propagate.  Weight each live loss tensor directly.
                    losses = [loss_percep, loss_texture, loss_g, loss_bp,
                              loss_reconst]
                    total_loss = sum(
                        w * l for w, l in zip(self.loss_weights, losses))

                    t.set_postfix_str("Batch loss {:.4f}".format(
                        total_loss.item()))
                    t.update()

                    # loss_d and loss_g share the discriminator forward graph,
                    # so the first backward must retain it for the second.
                    loss_d.backward(retain_graph=True)
                    total_loss.backward()
                    # Fixed: step the underlying Adam optimizers — the
                    # MultiStepLR schedulers only adjust the learning rate and
                    # never update any weights.
                    self.optimizer.optimizer.step()
                    self.optimizer_d.optimizer.step()
                    self.optimizer.step(self.epoch)
                    self.optimizer_d.step(self.epoch)
        else:
            # Past the scheduled number of epochs: nothing to do.
            pass

    def eval(self):
        """Evaluate on the validation set, recording PSNR/SSIM per epoch."""
        with tqdm(total=len(self.val_loader),
                  miniters=1,
                  desc='Val Epoch: [{}/{}]'.format(self.epoch,
                                                   self.max_epochs)) as t:
            psnr_list, ssim_list, loss_list = [], [], []
            for lr, hr in self.val_loader:
                lr, hr = lr.cuda(), hr.cuda()
                self.srntt.eval()
                with torch.no_grad():
                    sr, _ = self.srntt(lr, None, None)
                    loss = self.reconst_loss(sr, hr)

                batch_psnr, batch_ssim = [], []
                for c in range(sr.shape[0]):
                    # Map CHW in [-1, 1] to HWC in [0, 255] for the metrics.
                    predict_sr = (sr[c, ...].cpu().numpy().transpose(
                        (1, 2, 0)) + 1) * 127.5
                    ground_truth = (hr[c, ...].cpu().numpy().transpose(
                        (1, 2, 0)) + 1) * 127.5
                    psnr = utils.calculate_psnr(predict_sr, ground_truth, 255)
                    ssim = utils.calculate_ssim(predict_sr, ground_truth, 255)
                    batch_psnr.append(psnr)
                    batch_ssim.append(ssim)
                avg_psnr = np.array(batch_psnr).mean()
                avg_ssim = np.array(batch_ssim).mean()
                psnr_list.extend(batch_psnr)
                ssim_list.extend(batch_ssim)
                t.set_postfix_str(
                    'Batch loss: {:.4f}, PSNR: {:.4f}, SSIM: {:.4f}'.format(
                        loss.item(), avg_psnr, avg_ssim))
                t.update()
            self.records['Epoch'].append(self.epoch)
            self.records['PSNR'].append(np.array(psnr_list).mean())
            self.records['SSIM'].append(np.array(ssim_list).mean())
            self.logger.log('Val Epoch {}: PSNR={}, SSIM={}'.format(
                self.epoch, self.records['PSNR'][-1],
                self.records['SSIM'][-1]))

    def save_checkpoint(self):
        """Save model/optimizer state; copy to best.pth on a new best PSNR."""
        super(RefSRSolver, self).save_checkpoint()
        self.ckp['srntt'] = self.srntt.state_dict()
        self.ckp['optimizer'] = self.optimizer.state_dict()
        self.ckp['optimizer_d'] = self.optimizer_d.state_dict()
        self.ckp['optimizer_init'] = self.optimizer_init.state_dict()
        if self.discriminator is not None:
            self.ckp['discriminator'] = self.discriminator.state_dict()
        if self.vgg is not None:
            self.ckp['vgg'] = self.vgg.state_dict()

        torch.save(self.ckp, os.path.join(self.checkpoint_dir, 'latest.pth'))
        if self.records['PSNR'][-1] == np.array(self.records['PSNR']).max():
            shutil.copy(os.path.join(self.checkpoint_dir, 'latest.pth'),
                        os.path.join(self.checkpoint_dir, 'best.pth'))

    def load_checkpoint(self, model_path):
        """Restore model and optimizer state from ``model_path``."""
        super(RefSRSolver, self).load_checkpoint(model_path)
        ckpt = torch.load(model_path)
        self.srntt.load_state_dict(ckpt['srntt'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.optimizer_d.load_state_dict(ckpt['optimizer_d'])
        self.optimizer_init.load_state_dict(ckpt['optimizer_init'])
        if 'vgg' in ckpt.keys() and self.vgg is not None:
            # Fixed: was ``load_stat_dict(ckpt['srntt'])`` — misspelled method
            # and the wrong checkpoint key (would load SRNTT weights into VGG).
            self.vgg.load_state_dict(ckpt['vgg'])
        if 'discriminator' in ckpt.keys() and self.discriminator is not None:
            self.discriminator.load_state_dict(ckpt['discriminator'])
Ejemplo n.º 14
0
class Trial:
    def __init__(self,
                 data_dir: str = './dataset',
                 log_dir: str = './logs',
                 device: str = "cuda:0",
                 batch_size: int = 2,
                 init_lr: float = 0.5,
                 G_lr: float = 0.0004,
                 D_lr: float = 0.0008,
                 level: str = "O1",
                 patch: bool = False,
                 init_training_epoch: int = 10,
                 train_epoch: int = 10,
                 optim_type: str = "ADAM",
                 pin_memory: bool = True,
                 grad_set_to_none: bool = True):
        """Set up data, networks, optimizers, logging, and (optionally) AMP.

        data_dir/log_dir: dataset root and TensorBoard directory.
        init_lr: LR for the reconstruction-only warm-up; G_lr/D_lr: LRs for
        the adversarial phase.  level: apex AMP opt level ("O0" disables
        mixed precision).  patch: use a PatchDiscriminator instead of the
        full-image one.  grad_set_to_none: passed to zero_grad() for cheaper
        gradient clearing.
        """

        # self.config = config
        self.data_dir = data_dir

        # Style images from the Shinkai set; the same transform is applied
        # to the style and edge-smoothed variants.
        self.dataset = Dataset(root=data_dir + "/Shinkai",
                               style_transform=tr.transform,
                               smooth_transform=tr.transform)

        self.pin_memory = pin_memory
        self.batch_size = batch_size

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=4,
                                     pin_memory=pin_memory)

        # Fall back to CPU when CUDA is unavailable, regardless of `device`.
        self.device = torch.device(
            device) if torch.cuda.is_available() else torch.device('cpu')

        self.G = Generator().to(self.device)
        self.patch = patch
        if self.patch:
            self.D = PatchDiscriminator().to(self.device)
        else:
            self.D = Discriminator().to(self.device)

        self.init_model_weights()

        self.optimizer_G = GANOptimizer(optim_type,
                                        self.G.parameters(),
                                        lr=G_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=False)
        self.optimizer_D = GANOptimizer(optim_type,
                                        self.D.parameters(),
                                        lr=D_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=True)

        self.loss = Loss(device=self.device).to(self.device)

        self.init_lr = init_lr
        self.G_lr = G_lr
        self.D_lr = D_lr
        self.grad_set_to_none = grad_set_to_none

        self.writer = tensorboard.SummaryWriter(log_dir=log_dir)
        self.init_train_epoch = init_training_epoch
        self.train_epoch = train_epoch

        self.init_time = None
        self.level = level

        # apex AMP must wrap models and optimizers together, after both are
        # built; skipped entirely on CPU or at opt level "O0".
        if self.level != "O0" and device != "cpu":
            self.fp16 = True
            [self.G,
             self.D], [self.optimizer_G, self.optimizer_D
                       ] = amp.initialize([self.G, self.D],
                                          [self.optimizer_G, self.optimizer_D],
                                          opt_level=self.level)
        else:
            self.fp16 = False

    def init_model_weights(self):
        """(Re-)initialize the weights of both networks in place."""
        for network in (self.G, self.D):
            network.apply(weights_init)

    @classmethod
    def from_config(cls):
        """Alternate constructor from a config object — not implemented yet."""
        return None

    def init_train(self, con_weight: float = 1.0):
        """Warm-up phase: train G alone on content reconstruction.

        Runs ``init_train_epoch`` epochs of content loss only (scaled by
        ``con_weight``), logging the per-epoch loss sum, weight histograms,
        and a reconstructed test image.  Restores the generator LR to
        ``G_lr`` afterwards for the adversarial phase.
        """

        test_img = self.get_test_image()
        meter = AverageMeter("Loss")
        self.writer.flush()
        # NOTE(review): the OneCycleLR schedule (max_lr=0.9999) will override
        # the manual init_lr assignment below on its first step() — confirm
        # this interplay is intended.
        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=0.9999,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=self.init_train_epoch)

        # Start the warm-up at init_lr rather than the adversarial-phase G_lr.
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.init_lr

        for epoch in tqdm(range(self.init_train_epoch)):

            meter.reset()

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)

                generator_output = self.G(train)
                # content_loss = loss.reconstruction_loss(generator_output, train) * con_weight
                content_loss = self.loss.content_loss(generator_output,
                                                      train) * con_weight
                # content_loss = F.mse_loss(train, generator_output) * con_weight
                content_loss.backward()
                self.optimizer_G.step()
                # Per-batch scheduler step, as OneCycleLR expects.
                lr_scheduler.step()

                meter.update(content_loss.detach())

            self.writer.add_scalar(f"Loss : {self.init_time}",
                                   meter.sum.item(), epoch)
            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)

        # Hand the generator back its adversarial-phase learning rate.
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr

        # self.save_trial(self.init_train_epoch, "init")

    def eval_image(self, epoch: int, caption, img):
        """Run one image through the generator and log the styled result."""
        self.G.eval()
        batch = tr.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            styled = self.G(batch).to('cpu').squeeze()
        self.write_image(styled, caption, epoch + 1)
        self.writer.flush()
        # Restore training mode for the caller's next step.
        self.G.train()

    def write_image(self,
                    image: torch.Tensor,
                    img_caption: str = "sample_image",
                    step: int = 0):
        """Log a normalized CHW image tensor to TensorBoard as uint8 HWC."""
        # Undo normalization: [-1, 1] -> [0, 1], then scale to [0, 255].
        pixels = torch.clip(tr.inv_norm(image).to(torch.float), 0, 1) * 255.
        hwc = pixels.permute(1, 2, 0).to(dtype=torch.uint8)
        self.writer.add_image(img_caption, hwc, step, dataformats='HWC')
        self.writer.flush()

    def write_weights(self, epoch: int, write_D=True, write_G=True):
        """Log weight and gradient histograms for D and/or G to TensorBoard."""

        if write_D:
            # Only the separable-convolution layers of D are tracked.
            for name, weight in self.D.named_parameters():
                if 'depthwise' not in name and 'pointwise' not in name:
                    continue
                self.writer.add_histogram(
                    f"Discriminator {name} {self.init_time}", weight,
                    epoch)
                self.writer.add_histogram(
                    f"Discriminator {name}.grad {self.init_time}",
                    weight.grad, epoch)
                self.writer.flush()

        if write_G:
            # Every generator parameter is tracked.
            for name, weight in self.G.named_parameters():
                self.writer.add_histogram(f"Generator {name} {self.init_time}",
                                          weight, epoch)
                self.writer.add_histogram(
                    f"Generator {name}.grad {self.init_time}", weight.grad,
                    epoch)
                self.writer.flush()

    def train_1(
        self,
        adv_weight: float = 300.,
        con_weight: float = 1.5,
        gra_weight: float = 3.,
        col_weight: float = 10.,
    ):

        test_img_dir = Path(
            self.data_dir).joinpath('test/test_photo256').resolve()
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)
        self.writer.add_image(f'test image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()

        for epoch in tqdm(range(self.train_epoch)):

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):

                self.D.zero_grad()
                style = style.to(self.device)
                smooth = smooth.to(self.device)
                train = train.to(self.device)

                # style image to discriminator(Not Gram Matrix Loss)
                style_loss_value = self.D(style).view(-1)
                generator_output = self.G(train)
                # generated image to discriminator
                real_output = self.D(generator_output.detach()).view(-1)
                # greyscale_output = D(transforms.functional.rgb_to_grayscale(train, num_output_channels=3)).view(-1) #greyscale adversarial loss
                gray_train = tr.inv_gray_transform(train)
                greyscale_output = self.D(gray_train).view(-1)
                smoothed_loss = self.D(smooth).view(-1)  # smoothed image loss
                # loss_D_real = adversarial_loss(output, label)

                dis_adv_loss = adv_weight * (
                    torch.pow(style_loss_value - 1, 2).mean() +
                    torch.pow(real_output, 2).mean())
                dis_gray_loss = torch.pow(greyscale_output, 2).mean()
                dis_edge_loss = torch.pow(smoothed_loss, 2).mean()
                discriminator_loss = dis_adv_loss + dis_gray_loss + dis_edge_loss
                discriminator_loss.backward()
                self.optimizer_D.step()

                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'{self.init_time} Discriminator losses', {
                            'adversarial loss': dis_adv_loss.item(),
                            'grayscale loss': dis_gray_loss.item(),
                            'edge loss': dis_edge_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()

                real_output = self.D(generator_output).view(-1)
                per_loss = self.loss.perceptual_loss(
                    train, generator_output)  # loss for G
                style_loss = self.loss.style_loss(generator_output, style)
                content_loss = self.loss.content_loss(generator_output, train)
                recon_loss = self.loss.reconstruction_loss(
                    generator_output, train)
                tv_loss = self.loss.tv_loss(generator_output)
                '''
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, num_epoch, i, len(data_loader),
                      loss_D.item(), loss_G.item(), D_x, D_G_z1, D_G_z2))'''

                self.G.zero_grad()
                gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                      2).mean()
                gen_con_loss = con_weight * content_loss
                gen_sty_loss = gra_weight * style_loss
                gen_rec_loss = col_weight * recon_loss
                gen_per_loss = per_loss
                gen_tv_loss = tv_loss
                generator_loss = gen_adv_loss + gen_con_loss + gen_sty_loss + gen_rec_loss + gen_per_loss
                generator_loss.backward()
                self.optimizer_G.step()

                if i % 200 == 0 and i != 0:

                    self.writer.add_scalars(
                        f'generator losses {self.init_time}', {
                            'adversarial loss': gen_adv_loss.item(),
                            'content loss': gen_con_loss.item(),
                            'style loss': gen_sty_loss.item(),
                            'reconstruction loss': gen_rec_loss.item(),
                            'perceptual loss': gen_per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()

            self.write_weights(epoch + 1)
            self.eval_image(epoch, f'{self.init_time} style img', test_img)

    def train_2(self,
                adv_weight: float = 1.0,
                threshold: float = 3.,
                G_train_iter: int = 1,
                D_train_iter: int = 1
                ):  # if threshold is 0., set to half of adversarial loss
        """Adversarial-only GAN training with a ramped perceptual loss.

        Per batch, runs ``D_train_iter`` discriminator passes (one optimizer
        step) followed by ``G_train_iter`` generator passes (one optimizer
        step), both using LSGAN-style squared losses. After each epoch, the
        perceptual-loss weight is increased by 0.05 while the epoch's summed
        discriminator loss stays above ``threshold``; once it falls to or
        below the threshold the weight is frozen for the rest of training.

        Args:
            adv_weight: Scale applied to both adversarial losses.
            threshold: Epoch-total discriminator loss above which the
                perceptual weight keeps ramping.
            G_train_iter: Generator updates per batch.
            D_train_iter: Discriminator forward/backward passes per batch.
        """

        # Pick one random image from the test set to track visual progress.
        test_img_dir = Path(self.data_dir).joinpath('test', 'test_photo256')
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)

        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")

        self.writer.add_image(f'sample_image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()

        # Perceptual loss starts disabled and is ramped in per epoch (below).
        perception_weight = 0.
        keep_constant = False

        for epoch in tqdm(range(self.train_epoch)):

            total_dis_loss = 0.

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):

                self.D.zero_grad()

                train = train.to(self.device)
                style = style.to(self.device)
                # smooth = smooth.to(device)

                # Discriminator: real (style) images toward 1, generated
                # images toward 0 (least-squares GAN objective).
                # NOTE(review): zero_grad is outside this loop, so gradients
                # accumulate over all D_train_iter passes before the single
                # optimizer step below — confirm this is intended.
                for _ in range(D_train_iter):
                    style_loss_value = self.D(style).view(-1)
                    generator_output = self.G(train)
                    real_output = self.D(generator_output.detach()).view(-1)
                    dis_adv_loss = adv_weight * \
                        (torch.pow(style_loss_value - 1, 2).mean() + torch.pow(real_output, 2).mean())
                    total_dis_loss += dis_adv_loss.item()
                    dis_adv_loss.backward()
                self.optimizer_D.step()

                # Generator: push D's output on generated images toward 1,
                # plus the (possibly zero-weighted) perceptual loss.
                self.G.zero_grad()
                for _ in range(G_train_iter):
                    generator_output = self.G(train)
                    real_output = self.D(generator_output).view(-1)
                    per_loss = perception_weight * \
                        self.loss.perceptual_loss(train, generator_output)
                    gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                          2).mean()
                    gen_loss = gen_adv_loss + per_loss
                    gen_loss.backward()
                self.optimizer_G.step()

                # Periodic scalar logging (skips the first batch).
                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'generator losses  {self.init_time}', {
                            'adversarial loss': dis_adv_loss.item(),
                            'Generator adversarial loss': gen_adv_loss.item(),
                            'perceptual loss': per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()

            # Ramp the perceptual weight while D loss is high; once the
            # epoch total drops to/below the threshold, freeze it forever.
            if total_dis_loss > threshold and not keep_constant:
                perception_weight += 0.05
            else:
                keep_constant = True

            self.writer.add_scalar(
                f'total discriminator loss {self.init_time}', total_dis_loss,
                i + epoch * len(self.dataloader))

            self.write_weights()

            # Render the tracked test image with G in eval mode.
            self.G.eval()

            styled_test_img = tr.transform(test_img).unsqueeze(0).to(
                self.device)
            with torch.no_grad():
                styled_test_img = self.G(styled_test_img)

            styled_test_img = styled_test_img.to('cpu').squeeze()
            self.write_image(styled_test_img, f'styled image {self.init_time}',
                             epoch + 1)

            self.G.train()

    def __call__(self):
        """Run the default pipeline: initialize training state, then stage 1."""
        self.init_train()
        self.train_1()

    def save_trial(self, epoch: int, train_type: str):
        """Write a checkpoint with model and optimizer state for G and D.

        The file is saved as ``<train_type>_<level>.pt`` in the current
        working directory. When mixed precision is active, the amp state is
        bundled into the same checkpoint so it can be restored atomically.

        Args:
            epoch: Epoch number recorded inside the checkpoint.
            train_type: Prefix used to build the checkpoint filename.
        """
        checkpoint_path = Path(f"{train_type}_{self.level}.pt")
        checkpoint = {
            "epoch": epoch,
            "gen": {
                "gen_state_dict": self.G.state_dict(),
                "optim_G_state_dict": self.optimizer_G.state_dict(),
            },
            "dis": {
                "dis_state_dict": self.D.state_dict(),
                "optim_D_state_dict": self.optimizer_D.state_dict(),
            },
        }
        if self.fp16:
            # amp scaling state must travel with the weights it scaled.
            checkpoint["amp"] = amp.state_dict()

        torch.save(checkpoint, checkpoint_path.as_posix())

    def load_trial(self, dir: Path):
        """Restore a checkpoint written by :meth:`save_trial`.

        Loads generator/discriminator weights and their optimizer states,
        plus the amp state when mixed precision is enabled.

        Args:
            dir: Path to a ``.pt`` checkpoint file (the parameter name is
                historical — it must be a file, not a directory).

        Raises:
            AssertionError: If *dir* is not an existing ``.pt`` file.
        """
        # Bug fix: the message previously said "No such directory" although
        # the check requires an existing *file*.
        assert dir.is_file(), "No such file"
        assert dir.suffix == ".pt", "Filetype not compatible"
        state_dicts = torch.load(dir.as_posix())
        self.G.load_state_dict(state_dicts["gen"]["gen_state_dict"])
        self.optimizer_G.load_state_dict(
            state_dicts["gen"]["optim_G_state_dict"])
        self.D.load_state_dict(state_dicts["dis"]["dis_state_dict"])
        self.optimizer_D.load_state_dict(
            state_dicts["dis"]["optim_D_state_dict"])
        if self.fp16:
            amp.load_state_dict(state_dicts["amp"])
        typer.echo("Loaded Weights")

    def Generator_NOGAN(self,
                        epochs: int = 1,
                        style_weight: float = 20.,
                        content_weight: float = 1.2,
                        recon_weight: float = 10.,
                        tv_weight: float = 1e-6,
                        loss: List[str] = ['content_loss']):
        """Training Generator in NOGAN manner (Feature Loss only).

        Pre-trains the generator without a discriminator: only the feature
        losses named in ``loss`` (subset of ``style_loss``, ``content_loss``,
        ``recon_loss``, ``tv_loss``) contribute to the objective. Uses a
        OneCycleLR schedule peaking at 10x the base generator LR. Training
        stops early once the gradient of the per-epoch total loss rises
        above -1.0 (i.e. the loss curve has flattened); a checkpoint is
        saved at the end either way.

        NOTE(review): the mutable default for ``loss`` is never mutated here,
        so it is harmless, but ``loss=[]`` would make ``total_loss`` a plain
        float and crash on ``.backward()`` — confirm callers never do that.
        """
        # Reset the generator LR before installing the one-cycle schedule.
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr
        test_img = self.get_test_image()
        max_lr = self.G_lr * 10.

        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)

        meter = LossMeters(*loss)
        # Per-epoch totals, used below for the early-stop gradient check.
        total_loss_arr = np.array([])

        for epoch in tqdm(range(epochs)):

            total_losses = 0
            meter.reset()

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)

                generator_output = self.G(train)
                # Each component is computed only when requested; otherwise
                # it contributes 0. to the total.
                if 'style_loss' in loss:
                    style = style.to(self.device)
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.

                if 'content_loss' in loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.

                if 'recon_loss' in loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.

                if 'tv_loss' in loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.

                total_loss = content_loss + tv_loss + recon_loss + style_loss
                # Mixed precision path scales the loss before backward.
                if self.fp16:
                    with amp.scale_loss(total_loss,
                                        self.optimizer_G) as scaled_loss:
                        scaled_loss.backward()
                else:
                    total_loss.backward()

                self.optimizer_G.step()
                lr_scheduler.step()
                total_losses += total_loss.detach()
                loss_dict = {
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }

                # Only the selected losses are tensors; detach for logging.
                losses = [loss_dict[loss_type].detach() for loss_type in loss]
                meter.update(*losses)

            total_loss_arr = np.append(total_loss_arr, total_losses.item())
            self.writer.add_scalars(f'{self.init_time} NOGAN generator losses',
                                    meter.as_dict('sum'), epoch)

            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)
            # Early stop: once the loss curve's slope rises above -1.0
            # (training has plateaued), abandon the remaining epochs. The
            # gradient plot is also logged to TensorBoard for inspection.
            if epoch > 2:
                fig = plt.figure(figsize=(8, 8))
                X = np.arange(len(total_loss_arr))
                Y = np.gradient(total_loss_arr)
                plt.plot(X, Y)
                thresh = -1.0
                plt.axhline(thresh, c='r')
                plt.title(f"{self.init_time}")
                self.writer.add_figure(f"{self.init_time}", fig, epoch)
                if Y[-1] > thresh:
                    break

        self.save_trial(epoch, f'G_NG_{self.init_time}')

    def Discriminator_NOGAN(
            self,
            epochs: int = 3,
            adv_weight: float = 1.0,
            edge_weight: float = 1.0,
            loss: List[str] = ['real_adv_loss', 'fake_adv_loss', 'gray_loss']):
        """Train the discriminator alone (NOGAN stage) with LSGAN losses.

        The generator is only used to produce fakes; its weights are not
        updated. Real (style) images are pushed toward 1, generated and
        grayscale images toward 0. Uses a OneCycleLR schedule peaking at
        10x the base discriminator LR, and stops early once the per-epoch
        total loss curve flattens (gradient above -1.0).

        Batch-size scheduling reference:
        https://discuss.pytorch.org/t/scheduling-batch-size-in-dataloader/46443/2

        NOTE(review): ``edge_weight`` is currently unused — there is no
        smoothed/edge loss term in this stage.
        """

        # Reset the discriminator LR before installing the schedule.
        for g in self.optimizer_D.param_groups:
            g['lr'] = self.D_lr

        max_lr = self.D_lr * 10.
        lr_scheduler = OneCycleLR(self.optimizer_D,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)
        meter = LossMeters(*loss)
        total_loss_arr = np.array([])
        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")

        for epoch in tqdm(range(epochs)):

            meter.reset()
            # Bug fix: total_loss_arr was never appended to, so the
            # early-stop block below called np.gradient on an empty array
            # (ValueError) once epoch > 2. Accumulate the epoch total here
            # and append it after the batch loop, mirroring Generator_NOGAN.
            total_losses = 0.

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)

                generator_output = self.G(train)
                real_adv_loss = self.D(style).view(-1)
                fake_adv_loss = self.D(generator_output.detach()).view(-1)
                real_adv_loss = torch.pow(real_adv_loss - 1,
                                          2).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.pow(fake_adv_loss,
                                          2).mean() * 1.7 * adv_weight
                gray_train = tr.inv_gray_transform(style)
                greyscale_output = self.D(gray_train).view(-1)
                gray_loss = torch.pow(greyscale_output,
                                      2).mean() * 1.7 * adv_weight
                # Per the AnimeGANv2 implementation, every loss is scaled by
                # an individual weight and then by adv_weight:
                # https://github.com/TachibanaYoshino/AnimeGANv2/blob/5946b6afcca5fc28518b75a763c0f561ff5ce3d6/tools/ops.py#L217
                total_loss = real_adv_loss + fake_adv_loss + gray_loss
                if self.fp16:
                    with amp.scale_loss(total_loss,
                                        self.optimizer_D) as scaled_loss:
                        scaled_loss.backward()
                else:
                    total_loss.backward()
                self.optimizer_D.step()
                lr_scheduler.step()
                total_losses += total_loss.detach()

                loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss
                }

                losses = [loss_dict[loss_type].detach() for loss_type in loss]
                meter.update(*losses)

            total_loss_arr = np.append(total_loss_arr, total_losses.item())

            self.writer.add_scalars(
                f'{self.init_time} NOGAN discriminator loss',
                meter.as_dict('sum'), epoch)
            self.writer.flush()
            # Early stop once the loss curve's slope rises above -1.0.
            if epoch > 2:
                fig = plt.figure(figsize=(8, 8))
                X = np.arange(len(total_loss_arr))
                Y = np.gradient(total_loss_arr)
                plt.plot(X, Y)
                thresh = -1.0
                plt.axhline(thresh, c='r')
                plt.title(f"{self.init_time}")
                self.writer.add_figure(f"{self.init_time}", fig, epoch)
                if Y[-1] > thresh:
                    break

    def GAN_NOGAN(
        self,
        epochs: int = 1,
        GAN_G_lr: float = 0.00008,
        GAN_D_lr: float = 0.000016,
        D_loss: List[str] = [
            "real_adv_loss", "fake_adv_loss", "gray_loss", "edge_loss"
        ],
        adv_weight: float = 300.,
        edge_weight: float = 0.1,
        G_loss: List[str] = [
            "adv_loss", "content_loss", "style_loss", "recon_loss"
        ],
        style_weight: float = 20.,
        content_weight: float = 1.2,
        recon_weight: float = 10.,
        tv_weight: float = 1e-6,
    ):
        """Joint GAN fine-tuning stage after the NOGAN pre-training passes.

        Alternates one discriminator and one generator update per batch at
        the (low) learning rates given by ``GAN_G_lr``/``GAN_D_lr``.
        D is trained with LSGAN losses on real/fake/grayscale/smoothed
        images; G is trained with an adversarial term plus the feature
        losses selected in ``G_loss``. Scalars and an eval image are logged
        roughly 20 times per epoch.

        NOTE(review): ``edge_weight`` and ``count`` are unused, and
        ``gray_smooth_data`` is computed but ``smoothed_output`` feeds the
        original ``smooth`` batch to D instead — confirm whether the
        grayscale version was intended. Mutable list defaults are never
        mutated here, so they are benign.
        """

        test_img = self.get_test_image()
        dis_meter = LossMeters(*D_loss)
        gen_meter = LossMeters(*G_loss)

        # Override both optimizers' learning rates for this stage.
        for g in self.optimizer_G.param_groups:
            g['lr'] = GAN_G_lr

        for g in self.optimizer_D.param_groups:
            g['lr'] = GAN_D_lr

        # Log ~20 times per epoch.
        update_duration = len(self.dataloader) // 20

        for epoch in tqdm(range(epochs)):

            G_loss_arr = np.array([])
            dis_meter.reset()
            count = 0

            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)
                smooth = smooth.to(self.device)

                # All discriminator scores are computed up front;
                # G_adv_loss (on the non-detached output) is reused for the
                # generator step below.
                generator_output = self.G(train)
                real_adv_loss = self.D(style).view(-1)
                fake_adv_loss = self.D(generator_output.detach()).view(-1)
                G_adv_loss = self.D(generator_output).view(-1)
                gray_train = tr.inv_gray_transform(style)
                grayscale_output = self.D(gray_train).view(-1)
                gray_smooth_data = tr.inv_gray_transform(smooth)
                smoothed_output = self.D(smooth).view(-1)

                # LSGAN targets: real -> 1, everything else -> 0.
                real_adv_loss = torch.square(real_adv_loss -
                                             1.).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.square(
                    fake_adv_loss).mean() * 1.7 * adv_weight
                gray_loss = torch.square(
                    grayscale_output).mean() * 1.7 * adv_weight
                edge_loss = torch.square(
                    smoothed_output).mean() * 1.0 * adv_weight

                total_D_loss = real_adv_loss + fake_adv_loss + gray_loss + edge_loss
                total_D_loss.backward()
                self.optimizer_D.step()

                D_loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss,
                    'edge_loss': edge_loss
                }

                loss = list(D_loss_dict.values())

                dis_meter.update(*loss)

                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(f'{self.init_time} NOGAN Dis loss',
                                            dis_meter.as_dict('val'),
                                            i + epoch * len(self.dataloader))
                    self.writer.flush()

                # Generator update: adversarial term plus selected feature
                # losses (unselected ones contribute 0.).
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                G_adv_loss = torch.square(G_adv_loss - 1.).mean() * adv_weight

                if 'style_loss' in G_loss:
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.

                if 'content_loss' in G_loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.

                if 'recon_loss' in G_loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.

                if 'tv_loss' in G_loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.

                total_G_loss = G_adv_loss + content_loss + tv_loss + recon_loss + style_loss
                total_G_loss.backward()
                self.optimizer_G.step()

                G_loss_dict = {
                    'adv_loss': G_adv_loss,
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }

                losses = [
                    G_loss_dict[loss_type].detach() for loss_type in G_loss
                ]
                gen_meter.update(*losses)

                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(f'{self.init_time} NOGAN Gen loss',
                                            gen_meter.as_dict('val'),
                                            i + epoch * len(self.dataloader))
                    self.writer.flush()
                    G_loss_arr = np.append(G_loss_arr, G_adv_loss.item())
                    self.eval_image(i + epoch * len(self.dataloader),
                                    f'{self.init_time} reconstructed img',
                                    test_img)

        self.save_trial(epoch, f'GAN_NG_{self.init_time}')

    def get_test_image(self):
        """Pick a random image from the test folder, log it, and return it.

        Also refreshes ``self.init_time``, which is used as the tag prefix
        for every TensorBoard entry written during this run.
        """
        candidates = list(
            Path(self.data_dir).joinpath('test/test_photo256').glob('**/*'))
        chosen_path = random.choice(candidates)
        sample = Image.open(chosen_path)
        self.init_time = datetime.datetime.now().strftime("%H:%M")
        self.writer.add_image(f'{self.init_time} sample_image',
                              np.asarray(sample),
                              dataformats='HWC')
        self.writer.flush()
        return sample
Ejemplo n.º 15
0
    }

    cvae = CVAE(opts.latent_size, device).to(device)
    dis = Discriminator().to(device)
    classifier = Classifier(opts.latent_size).to(device)
    classer = CLASSIFIERS().to(device)

    print(cvae)
    print(dis)
    print(classifier)

    optimizer_cvae = torch.optim.Adam(cvae.parameters(),
                                      lr=opts.lr,
                                      betas=(opts.b1, opts.b2),
                                      weight_decay=opts.weight_decay)
    optimizer_dis = torch.optim.Adam(dis.parameters(),
                                     lr=opts.lr,
                                     betas=(opts.b1, opts.b2),
                                     weight_decay=opts.weight_decay)
    optimizer_classifier = torch.optim.Adam(classifier.parameters(),
                                            lr=opts.lr,
                                            betas=(opts.b1, opts.b2),
                                            weight_decay=opts.weight_decay)

    i = 1
    while os.path.isdir('./ex/' + str(i)):
        i += 1
    os.mkdir('./ex/' + str(i))
    output_path = './ex/' + str(i)

    losses = {
Ejemplo n.º 16
0
    model_G.cuda()
    model_D.cuda()
    print('cuda is available!')

else:
    print('cuda is not available')

# パラメータ設定
# params_G = optim.Adam(model_G.parameters(),
#     lr=0.0002, betas=(0.5, 0.999))
# params_D = optim.Adam(model_D.parameters(),
#     lr=0.0002, betas=(0.5, 0.999))

params_G = optim.Adam(model_G.parameters(),
    lr=0.01)
params_D = optim.Adam(model_D.parameters(),
    lr=0.01)

# 潜在特徴100次元ベクトルz
nz = 100

# ロスを計算するときのラベル変数
if cuda:
    ones = torch.ones(batch_size).cuda() # 正例 1
    zeros = torch.zeros(batch_size).cuda() # 負例 0

loss_f = nn.BCEWithLogitsLoss()

# 途中結果の確認用の潜在特徴z
check_z = torch.randn(batch_size, nz, 1, 1).cuda()
Ejemplo n.º 17
0
def train():
    """Standard DCGAN training loop on a local image-folder dataset.

    Trains ``Generator``/``Discriminator`` (defined elsewhere in this file)
    with BCE loss: D sees real then fake batches, G is updated to make D
    label its fakes as real. Periodically snapshots G's output on a fixed
    noise vector and saves both models at the end. Dataset path and all
    hyperparameters are hard-coded below.
    """
    # Random Seed — logged so a run can be reproduced.
    manual_seed = random.randint(1, 10000)
    print('Random Seed: ', manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    # Parameter (hard-coded hyperparameters and local dataset path)
    dataroot = 'E:\\datasets\\ukiyoe-1024'
    workers = 2
    batch_size = 128
    image_size = 64
    nz = 100  # latent vector dimensionality
    num_epochs = 100

    lr = 0.0002
    beta1 = 0.5  # Adam beta1, per the DCGAN paper
    ngpu = 1

    # create dataset and dataloader (resize/crop to 64x64, normalize to [-1, 1])
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers)

    # define device
    device = torch.device('cuda:0' if (
        torch.cuda.is_available() and ngpu > 0) else 'cpu')
    print('device: ', device)

    netG = Generator(ngpu).to(device)
    netG.apply(weight_init)

    netD = Discriminator(ngpu).to(device)
    netD.apply(weight_init)

    criterion = nn.BCELoss()
    # Fixed noise lets us watch the same latent points evolve over training.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    real_label = 1.
    fake_label = 0.

    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))

    # training loop
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print('Starting Training Loop.')
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # Update D network: maximize log(D(x)) + log(1 - D(G(z))).
            # Real and fake losses are backpropagated separately; gradients
            # accumulate before the single optimizer step.
            netD.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)

            output = netD(real_cpu).view(-1)
            errD_real = criterion(output, label)

            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)

            # detach() so this backward pass does not touch G's parameters.
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)

            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake

            optimizerD.step()

            # Update G network: train G so D labels its fakes as real.
            netG.zero_grad()
            label.fill_(real_label)

            output = netD(fake).view(-1)
            errG = criterion(output, label)

            errG.backward()
            D_G_z2 = output.mean().item()

            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch + 1, num_epochs, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 100 == 0) or ((epoch == num_epochs - 1) and
                                      (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1

        save_g_images(epoch, img_list)

    # Save the full models (not just state_dicts) for later inference.
    model_path_G = '.\\output\\model\\generator.pth'
    model_path_D = '.\\output\\model\\discriminator.pth'
    torch.save(netG, model_path_G)
    torch.save(netD, model_path_D)

    print('Finish training.')
disc.apply(init_weights)
disc.to(device)

#initialize the gan,content,perceptual and brightness loss
gan_criterion = nn.BCELoss().to(device)
content_criterion = nn.L1Loss().to(device)
perceptual_criterion = nn.MSELoss().to(device)
brightness_criterion = nn.L1Loss().to(device)

#set the feature extractor to evaluation mode as it will be used only to calculate perceptual loss
feature_extractor = FeatureExtractor()
feature_extractor.eval()
feature_extractor.to(device)

#initialize the optimizers for Generator and Discriminator
optimizerD = optim.Adam(disc.parameters(), lr=0.0003, betas=(0.5, 0.999))
optimizerG = optim.Adam(gen.parameters(), lr=0.0001, betas=(0.5, 0.999))

alpha = 0.5
beta = 1.8
gamma = 1.97
delta = 0.069

resume_epoch = 0

for e in range(resume_epoch, epochs):
    for i, data in enumerate(train_loader):

        hazy_images, clear_images = data

        #to prevent accumulation of gradients
Ejemplo n.º 19
0
    try:
        if not os.path.exists(directory):
            os.makedirs(directory)
            print("Create directory: " + directory)
    except OSError:
        print('Error: Creating directory. ' + directory)
    torch.save(model.state_dict(),
               "ignore/weights/%s/%s_s.pth" % (dataset, tb, mode))


# init optimizer and scheduler
print(" " * 75, "\r", "Loading optimizer...", end="\r")
optimizer_G = torch.optim.SGD(itertools.chain(G_model_AtoB.parameters(),
                                              G_model_BtoA.parameters()),
                              lr=args.lr)  #, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.SGD(D_model_A.parameters(),
                                lr=args.lr)  #, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.SGD(D_model_B.parameters(),
                                lr=args.lr)  #, betas=(0.5, 0.999))
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=LambdaLR(
                                                       args.epochs, 0,
                                                       100).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A,
                                                     lr_lambda=LambdaLR(
                                                         args.epochs, 0,
                                                         100).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B,
                                                     lr_lambda=LambdaLR(
                                                         args.epochs, 0,
                                                         100).step)
Ejemplo n.º 20
0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

generator = Generator(512, 512).to(device)
discriminator = Discriminator().to(device)

g_optimizer = torch.optim.Adam(
    [{
        'params': generator.generator_mapping.parameters(),
        'lr': 0.001 * 0.01
    }, {
        'params': generator.generator_synth.parameters()
    }],
    lr=0.001,
    betas=(0., 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(),
                               lr=0.001,
                               betas=(0., 0.99))

############

summ_counter = 0
mean_losses = np.zeros(5)
batch_sizes = [256, 128, 64, 32, 16, 8]
epoch_sizes = [2, 4, 4, 8, 8, 16]
latent_const = torch.from_numpy(np.load('randn.npy')).float().to(device)

transform = transforms.Compose([
    transforms.CenterCrop([178, 178]),
    transforms.Resize([128, 128]),
    transforms.RandomHorizontalFlip(0.5),
Ejemplo n.º 21
0
def main():
    """GAIL-style adversarial imitation learning loop.

    Alternates between (1) fitting a discriminator to separate expert
    (state, action) pairs from policy-generated ones, and (2) PPO updates
    of the policy/value networks using the discriminator score as reward.

    Relies on module-level globals: args, device, model_name, write_scalar,
    load_model, onehot_action_sections, onehot_state_sections.
    """
    ## load std models
    # policy_log_std = torch.load('./model_pkl/policy_net_action_std_model_1.pkl')
    # transition_log_std = torch.load('./model_pkl/transition_net_state_std_model_1.pkl')

    # load expert data
    print(args.data_set_path)
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=True,
                                  num_workers=0)
    # define actor/critic/discriminator net and optimizer
    policy = Policy(onehot_action_sections,
                    onehot_state_sections,
                    state_0=dataset.state)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    if write_scalar:
        writer = SummaryWriter(log_dir='runs/' + model_name)

    # load net  models
    if load_model:
        discriminator.load_state_dict(
            torch.load('./model_pkl/Discriminator_model_' + model_name +
                       '.pkl'))
        policy.transition_net.load_state_dict(
            torch.load('./model_pkl/Transition_model_' + model_name + '.pkl'))
        policy.policy_net.load_state_dict(
            torch.load('./model_pkl/Policy_model_' + model_name + '.pkl'))
        value.load_state_dict(
            torch.load('./model_pkl/Value_model_' + model_name + '.pkl'))

        # The log-std tensors are pickled whole (not state_dicts), hence
        # plain attribute assignment here.
        policy.policy_net_action_std = torch.load(
            './model_pkl/Policy_net_action_std_model_' + model_name + '.pkl')
        policy.transition_net_state_std = torch.load(
            './model_pkl/Transition_net_state_std_model_' + model_name +
            '.pkl')
    print('#############  start training  ##############')

    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        policy.train()
        value.train()
        discriminator.train()
        start_time = time.time()
        memory, n_trajs = policy.collect_samples(
            batch_size=args.sample_batch_size)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        # Flatten every per-trajectory field to shape
        # (n_trajs * sample_traj_length, feat) and detach so neither the
        # D update nor PPO backprops into the sampling graph.
        onehot_state = torch.cat(batch.onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_state = torch.cat(batch.multihot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_state = torch.cat(batch.continuous_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()

        onehot_action = torch.cat(batch.onehot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_action = torch.cat(batch.multihot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_action = torch.cat(batch.continuous_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_onehot_state = torch.cat(batch.next_onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_multihot_state = torch.cat(batch.next_multihot_state,
                                        dim=1).reshape(
                                            n_trajs * args.sample_traj_length,
                                            -1).detach()
        next_continuous_state = torch.cat(
            batch.next_continuous_state,
            dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()

        old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        mask = torch.cat(batch.mask,
                         dim=1).reshape(n_trajs * args.sample_traj_length,
                                        -1).detach()
        # Full generated state/action tensors for the discriminator.
        gen_state = torch.cat((onehot_state, multihot_state, continuous_state),
                              dim=-1)
        gen_action = torch.cat(
            (onehot_action, multihot_action, continuous_action), dim=-1)
        # NOTE(review): `ep % 1 == 0` is always true — vestigial gate for
        # tuning how often D updates run (see commented-out variant below).
        if ep % 1 == 0:
            # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0):
            d_loss = torch.empty(0, device=device)
            p_loss = torch.empty(0, device=device)
            v_loss = torch.empty(0, device=device)
            gen_r = torch.empty(0, device=device)
            expert_r = torch.empty(0, device=device)
            for expert_state_batch, expert_action_batch in data_loader:
                # Instance-noise regularization: Gaussian noise added to
                # both generated and expert inputs of the discriminator.
                noise1 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_state.shape,
                                      device=device)
                noise2 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_action.shape,
                                      device=device)
                noise3 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_state_batch.shape,
                                      device=device)
                noise4 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_action_batch.shape,
                                      device=device)
                gen_r = discriminator(gen_state + noise1, gen_action + noise2)
                expert_r = discriminator(
                    expert_state_batch.to(device) + noise3,
                    expert_action_batch.to(device) + noise4)

                # gen_r = discriminator(gen_state, gen_action)
                # expert_r = discriminator(expert_state_batch.to(device), expert_action_batch.to(device))
                optimizer_discriminator.zero_grad()
                # Standard GAN discriminator loss: generated -> 0, expert -> 1.
                d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \
                            discriminator_criterion(expert_r,torch.ones(expert_r.shape, device=device))
                # Variance bonus (would encourage spread-out scores), but
                # note only d_loss is backpropagated below.
                variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(
                    expert_r.to(device))
                total_d_loss = d_loss - 10 * variance
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
            if write_scalar:
                writer.add_scalar('d_loss', d_loss, ep)
                writer.add_scalar('total_d_loss', total_d_loss, ep)
                writer.add_scalar('variance', 10 * variance, ep)
        # NOTE(review): always-true gate again (PPO update frequency).
        if ep % 1 == 0:
            # update PPO
            noise1 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_state.shape,
                                  device=device)
            noise2 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_action.shape,
                                  device=device)
            # Discriminator score of the generated rollout is used directly
            # as the PPO reward signal (see gen_r_batch below).
            gen_r = discriminator(gen_state + noise1, gen_action + noise2)
            #if gen_r.mean().item() < 0.1:
            #    d_stop = True
            #if d_stop and gen_r.mean()
            optimize_iter_num = int(
                math.ceil(onehot_state.shape[0] / args.ppo_mini_batch_size))
            # gen_r = -(1 - gen_r + 1e-10).log()
            for ppo_ep in range(args.ppo_optim_epoch):
                for i in range(optimize_iter_num):
                    num += 1
                    # Mini-batch slice; the min() clamps the last partial batch.
                    index = slice(
                        i * args.ppo_mini_batch_size,
                        min((i + 1) * args.ppo_mini_batch_size,
                            onehot_state.shape[0]))
                    onehot_state_batch, multihot_state_batch, continuous_state_batch, onehot_action_batch, multihot_action_batch, continuous_action_batch, \
                    old_log_prob_batch, mask_batch, next_onehot_state_batch, next_multihot_state_batch, next_continuous_state_batch, gen_r_batch = \
                        onehot_state[index], multihot_state[index], continuous_state[index], onehot_action[index], multihot_action[index], continuous_action[index], \
                        old_log_prob[index], mask[index], next_onehot_state[index], next_multihot_state[index], next_continuous_state[index], gen_r[
                            index]
                    v_loss, p_loss = ppo_step(
                        policy, value, optimizer_policy, optimizer_value,
                        onehot_state_batch, multihot_state_batch,
                        continuous_state_batch, onehot_action_batch,
                        multihot_action_batch, continuous_action_batch,
                        next_onehot_state_batch, next_multihot_state_batch,
                        next_continuous_state_batch, gen_r_batch,
                        old_log_prob_batch, mask_batch, args.ppo_clip_epsilon)
                    if write_scalar:
                        writer.add_scalar('p_loss', p_loss, ep)
                        writer.add_scalar('v_loss', v_loss, ep)
        # ---- end-of-epoch evaluation / logging ----
        policy.eval()
        value.eval()
        discriminator.eval()
        noise1 = torch.normal(0,
                              args.noise_std,
                              size=gen_state.shape,
                              device=device)
        noise2 = torch.normal(0,
                              args.noise_std,
                              size=gen_action.shape,
                              device=device)
        gen_r = discriminator(gen_state + noise1, gen_action + noise2)
        # NOTE(review): noise3/noise4 and expert_*_batch are stale values
        # from the LAST discriminator mini-batch above, and would raise
        # NameError if data_loader were empty — confirm this is intended.
        expert_r = discriminator(
            expert_state_batch.to(device) + noise3,
            expert_action_batch.to(device) + noise4)
        gen_r_noise = gen_r.mean().item()
        expert_r_noise = expert_r.mean().item()
        # Noise-free scores for comparison with the noisy ones.
        gen_r = discriminator(gen_state, gen_action)
        expert_r = discriminator(expert_state_batch.to(device),
                                 expert_action_batch.to(device))
        if write_scalar:
            writer.add_scalar('gen_r', gen_r.mean(), ep)
            writer.add_scalar('expert_r', expert_r.mean(), ep)
            writer.add_scalar('gen_r_noise', gen_r_noise, ep)
            writer.add_scalar('expert_r_noise', expert_r_noise, ep)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('gen_r_noise', gen_r_noise)
        print('expert_r_noise', expert_r_noise)
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        print('d_loss', d_loss.item())
        # save models
        if model_name is not None:
            torch.save(
                discriminator.state_dict(),
                './model_pkl/Discriminator_model_' + model_name + '.pkl')
            torch.save(policy.transition_net.state_dict(),
                       './model_pkl/Transition_model_' + model_name + '.pkl')
            torch.save(policy.policy_net.state_dict(),
                       './model_pkl/Policy_model_' + model_name + '.pkl')
            torch.save(
                policy.policy_net_action_std,
                './model_pkl/Policy_net_action_std_model_' + model_name +
                '.pkl')
            torch.save(
                policy.transition_net_state_std,
                './model_pkl/Transition_net_state_std_model_' + model_name +
                '.pkl')
            torch.save(value.state_dict(),
                       './model_pkl/Value_model_' + model_name + '.pkl')
        # Free rollout storage before the next epoch's collection.
        memory.clear_memory()
def main(data_dir):
    """Train PRNet (ResFCN256) for UV-position-map regression with an
    adversarial loss; the GAN flavor (GAN / WGAN / WGAN-GP / RGAN) is
    selected by FLAGS['gan_type'].

    Args:
        data_dir: root directory of the 300W-LP training data.

    All other hyper-parameters and paths come from the module-level
    FLAGS dict.
    """
    # 0) Tensoboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256(resolution_input=416,
                      resolution_output=416,
                      channel=3,
                      size=16)
    discriminator = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(FLAGS["device"])
    discriminator.to(FLAGS["device"])

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLAGS["lr"],
                                 betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    # SSIM is tracked for monitoring only (see stat_logit below).
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    # NOTE(review): `loss` (WeightMaskLoss) is constructed but never used
    # in this loop — dead code or a missing reconstruction-loss term?
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])

    #Loss function for adversarial
    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_G, stat_list = [], []
        loss_list_D = []
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(
                FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # Inference.
            optimizer.zero_grad()
            uv_map_result = model(origin)

            # Update D
            optimizer_D.zero_grad()
            # Detach so the D update does not backprop into the generator.
            fake_detach = uv_map_result.detach()
            d_fake = discriminator(fake_detach)
            d_real = discriminator(uv_map)
            retain_graph = False
            if FLAGS['gan_type'] == 'GAN':
                # NOTE(review): the BCE target here is the fake logits
                # rather than a 0/1 label tensor — this looks wrong
                # (compare the G update, which uses ones_like). Confirm.
                loss_d = bce_loss(d_real, d_fake)
            elif FLAGS['gan_type'].find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if FLAGS['gan_type'].find('GP') >= 0:
                    # WGAN-GP: gradient penalty on random interpolates
                    # between real and fake UV maps.
                    epsilon = torch.rand(fake_detach.shape[0]).view(
                        -1, 1, 1, 1)
                    epsilon = epsilon.to(fake_detach.device)
                    hat = fake_detach.mul(1 - epsilon) + uv_map.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = discriminator(hat)
                    gradients = torch.autograd.grad(outputs=d_hat.sum(),
                                                    inputs=hat,
                                                    retain_graph=True,
                                                    create_graph=True,
                                                    only_inputs=True)[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    # Penalty coefficient 10 as in the WGAN-GP paper.
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif FLAGS['gan_type'] == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = bce_loss(better_real, better_fake)
                # d_real is reused in the G update below, so keep the graph.
                retain_graph = True

            if discriminator.training:
                loss_list_D.append(loss_d.item())
                loss_d.backward(retain_graph=retain_graph)
                optimizer_D.step()

                if 'WGAN' in FLAGS['gan_type']:
                    # Vanilla-WGAN weight clipping (Lipschitz constraint).
                    for p in discriminator.parameters():
                        p.data.clamp_(-1, 1)

            # Update G
            d_fake_bp = discriminator(
                uv_map_result)  # for backpropagation, use fake as it is
            if FLAGS['gan_type'] == 'GAN':
                label_real = torch.ones_like(d_fake_bp)
                loss_g = bce_loss(d_fake_bp, label_real)
            elif FLAGS['gan_type'].find('WGAN') >= 0:
                loss_g = -d_fake_bp.mean()
            elif FLAGS['gan_type'] == 'RGAN':
                better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
                better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
                loss_g = bce_loss(better_fake, better_real)

            loss_g.backward()
            loss_list_G.append(loss_g.item())
            optimizer.step()

            # SSIM statistic for monitoring (not part of any backward pass).
            stat_logit = stat_loss(uv_map_result, uv_map)
            stat_list.append(stat_logit.item())
            #bar.set_description(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(ep, loss_list_G[-1], loss_list_D[-1],FLAGS["gauss_kernel"], stat_list[-1]))
            # Record Training information in Tensorboard.
            """
            if origin_img is None and uv_map_gt is None:
                origin_img, uv_map_gt = origin, uv_map
            uv_map_predicted = uv_map_result

            writer.add_scalar("Original Loss", loss_list_G[-1], FLAGS["summary_step"])
            writer.add_scalar("D Loss", loss_list_D[-1], FLAGS["summary_step"])
            writer.add_scalar("SSIM Loss", stat_list[-1], FLAGS["summary_step"])

            grid_1, grid_2, grid_3 = make_grid(origin_img, normalize=True), make_grid(uv_map_gt), make_grid(uv_map_predicted)

            writer.add_image('original', grid_1, FLAGS["summary_step"])
            writer.add_image('gt_uv_map', grid_2, FLAGS["summary_step"])
            writer.add_image('predicted_uv_map', grid_3, FLAGS["summary_step"])
            writer.add_graph(model, uv_map)
            """

        if ep % FLAGS["save_interval"] == 0:

            # Qualitative check on a fixed test image, then checkpoint.
            with torch.no_grad():
                print(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
                    ep, loss_list_G[-1], loss_list_D[-1],
                    FLAGS["gauss_kernel"], stat_list[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = model(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model
            print("Save model")
            state = {
                'prnet': model.state_dict(),
                'Loss': loss_list_G,
                'start_epoch': ep,
                'Loss_D': loss_list_D,
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

            # NOTE(review): scheduler.step() sits inside the save-interval
            # branch, so the LR decays only every save_interval epochs —
            # confirm this is intended rather than an indentation slip.
            scheduler.step()

    writer.close()
# ---- Ejemplo n.º 23 (0) ---- (scraper separator; original snippet boundary)
class Trainer(nn.Module):
    """Face-swap GAN trainer (AEI-Net style).

    Wraps the generator/discriminator pair, two frozen identity encoders
    (ArcFace for the training identity losses, MobileFaceNet for unbiased
    evaluation), their optimizers, and apex O1 mixed-precision state.

    Checkpoints and TensorBoard logs live under ``checkpoints/<model_dir>``;
    generated sample images under ``results/<model_dir>``.
    """

    def __init__(self, model_dir, g_optimizer, d_optimizer, lr, warmup,
                 max_iters):
        super().__init__()
        self.model_dir = model_dir
        if not os.path.exists(f'checkpoints/{model_dir}'):
            os.makedirs(f'checkpoints/{model_dir}')
        self.logs_dir = f'checkpoints/{model_dir}/logs'
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        self.writer = SummaryWriter(self.logs_dir)

        # Frozen ArcFace encoder for the identity losses.
        # strict=False: the checkpoint may carry extra/missing keys.
        self.arcface = ArcFaceNet(50, 0.6, 'ir_se').cuda()
        self.arcface.eval()
        self.arcface.load_state_dict(torch.load(
            'checkpoints/model_ir_se50.pth', map_location='cuda'),
                                     strict=False)

        # Independent face encoder, used only for evaluation similarity.
        self.mobiface = MobileFaceNet(512).cuda()
        self.mobiface.eval()
        self.mobiface.load_state_dict(torch.load(
            'checkpoints/mobilefacenet.pth', map_location='cuda'),
                                      strict=False)

        self.generator = Generator().cuda()
        self.discriminator = Discriminator().cuda()

        # Weights of the generator loss terms (see g_loss).
        self.adversarial_weight = 1
        self.src_id_weight = 5
        self.tgt_id_weight = 1
        self.attributes_weight = 10
        self.reconstruction_weight = 10

        self.lr = lr
        self.warmup = warmup
        self.g_optimizer = g_optimizer(self.generator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))
        self.d_optimizer = d_optimizer(self.discriminator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))

        # apex O1 mixed precision for both model/optimizer pairs.
        self.generator, self.g_optimizer = amp.initialize(self.generator,
                                                          self.g_optimizer,
                                                          opt_level="O1")
        self.discriminator, self.d_optimizer = amp.initialize(
            self.discriminator, self.d_optimizer, opt_level="O1")

        # Stored as a non-trainable Parameter so the step counter is
        # saved/restored together with module state.
        self._iter = nn.Parameter(torch.tensor(1), requires_grad=False)
        self.max_iters = max_iters

        if torch.cuda.is_available():
            self.cuda()

    @property
    def iter(self):
        """Current global training step as a plain int."""
        return self._iter.item()

    @property
    def device(self):
        """Device holding this module's parameters."""
        return next(self.parameters()).device

    def adapt(self, args):
        """Move every tensor in *args* onto this module's device."""
        device = self.device
        return [arg.to(device) for arg in args]

    def train_loop(self, dataloaders, eval_every, generate_every, save_every):
        """Run training: one G step + one D step per batch, with periodic
        evaluation, sample generation and checkpointing."""
        for batch in tqdm(dataloaders['train']):
            torch.Tensor.add_(self._iter, 1)
            # generator step
            # if self.iter % 2 == 0:
            # self.adjust_lr(self.g_optimizer)
            g_losses = self.g_step(self.adapt(batch))
            g_stats = self.get_opt_stats(self.g_optimizer, type='generator')
            self.write_logs(losses=g_losses, stats=g_stats, type='generator')

            # #discriminator step
            # if self.iter % 2 == 1:
            # self.adjust_lr(self.d_optimizer)
            d_losses = self.d_step(self.adapt(batch))
            d_stats = self.get_opt_stats(self.d_optimizer,
                                         type='discriminator')
            self.write_logs(losses=d_losses,
                            stats=d_stats,
                            type='discriminator')

            if self.iter % eval_every == 0:
                discriminator_acc = self.evaluate_discriminator_accuracy(
                    dataloaders['val'])
                identification_acc = self.evaluate_identification_similarity(
                    dataloaders['val'])
                metrics = {**discriminator_acc, **identification_acc}
                self.write_logs(metrics=metrics)

            if self.iter % generate_every == 0:
                self.generate(*self.adapt(batch))

            if self.iter % save_every == 0:
                self.save_discriminator()
                self.save_generator()

    def g_step(self, batch):
        """One generator update; returns the component losses as floats."""
        self.generator.train()
        self.g_optimizer.zero_grad()
        L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator = self.g_loss(
            *batch)
        with amp.scale_loss(L_generator, self.g_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.g_optimizer.step()

        losses = {
            'adv': L_adv.item(),
            'src_id': L_src_id.item(),
            'tgt_id': L_tgt_id.item(),
            'attributes': L_attr.item(),
            'reconstruction': L_rec.item(),
            'total_loss': L_generator.item()
        }
        return losses

    def d_step(self, batch):
        """One discriminator update; returns the hinge losses as floats."""
        self.discriminator.train()
        self.d_optimizer.zero_grad()
        L_fake, L_real, L_discriminator = self.d_loss(*batch)
        with amp.scale_loss(L_discriminator, self.d_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.d_optimizer.step()

        losses = {
            'hinge_fake': L_fake.item(),
            'hinge_real': L_real.item(),
            'total_loss': L_discriminator.item()
        }
        return losses

    def g_loss(self, Xs, Xt, same_person):
        """Generator loss: adversarial + src/tgt identity + attribute +
        reconstruction terms, combined with the instance weights.

        NOTE(review): the [19:237] crop assumes 256x256 inputs — confirm
        against the dataloader.
        """
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            tgt_embed = self.arcface(
                F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))

        Y_hat, Xt_attr = self.generator(Xt, src_embed, return_attributes=True)

        Di = self.discriminator(Y_hat)

        # Multi-scale discriminator: sum hinge loss over scales.
        L_adv = 0
        for di in Di:
            L_adv += hinge_loss(di[0], True)

        fake_embed = self.arcface(
            F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear',
                          align_corners=True))
        L_src_id = (
            1 - torch.cosine_similarity(src_embed, fake_embed, dim=1)).mean()
        L_tgt_id = (
            1 - torch.cosine_similarity(tgt_embed, fake_embed, dim=1)).mean()

        batch_size = Xs.shape[0]
        Y_hat_attr = self.generator.get_attr(Y_hat)
        L_attr = 0
        for i in range(len(Xt_attr)):
            L_attr += torch.mean(torch.pow(Xt_attr[i] - Y_hat_attr[i],
                                           2).reshape(batch_size, -1),
                                 dim=1).mean()
        L_attr /= 2.0

        # Reconstruction applies only where source == target (same_person).
        L_rec = torch.sum(
            0.5 * torch.mean(torch.pow(Y_hat - Xt, 2).reshape(batch_size, -1),
                             dim=1) * same_person) / (same_person.sum() + 1e-6)
        L_generator = (self.adversarial_weight *
                       L_adv) + (self.src_id_weight * L_src_id) + (
                           self.tgt_id_weight *
                           L_tgt_id) + (self.attributes_weight * L_attr) + (
                               self.reconstruction_weight * L_rec)
        return L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator

    def d_loss(self, Xs, Xt, same_person):
        """Hinge discriminator loss on detached fakes vs. real sources."""
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
        Y_hat = self.generator(Xt, src_embed, return_attributes=False)

        fake_D = self.discriminator(Y_hat.detach())
        L_fake = 0
        for di in fake_D:
            L_fake += hinge_loss(di[0], False)
        real_D = self.discriminator(Xs)
        L_real = 0
        for di in real_D:
            L_real += hinge_loss(di[0], True)

        L_discriminator = 0.5 * (L_real + L_fake)
        return L_fake, L_real, L_discriminator

    def evaluate_discriminator_accuracy(self, val_dataloader):
        """Fraction of correctly-signed hinge scores on the val split,
        averaged over discriminator scales; returns percentages."""
        real_acc = 0
        fake_acc = 0
        self.generator.eval()
        self.discriminator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)

            with torch.no_grad():
                embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, embed, return_attributes=False)
                fake_D = self.discriminator(Y_hat)
                real_D = self.discriminator(Xs)

            fake_multiscale_acc = 0
            for di in fake_D:
                fake_multiscale_acc += torch.mean((di[0] < 0).float())
            fake_acc += fake_multiscale_acc / len(fake_D)

            real_multiscale_acc = 0
            for di in real_D:
                real_multiscale_acc += torch.mean((di[0] > 0).float())
            real_acc += real_multiscale_acc / len(real_D)

        self.generator.train()
        self.discriminator.train()

        metrics = {
            'fake_acc': 100 * (fake_acc / len(val_dataloader)).item(),
            'real_acc': 100 * (real_acc / len(val_dataloader)).item()
        }
        return metrics

    def evaluate_identification_similarity(self, val_dataloader):
        """Cosine similarity (percent) of swapped faces to source/target
        identities, measured with the independent MobileFaceNet encoder."""
        src_id_sim = 0
        tgt_id_sim = 0
        self.generator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)
            with torch.no_grad():
                src_embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, src_embed, return_attributes=False)

                src_embed = self.mobiface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                tgt_embed = self.mobiface(
                    F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                fake_embed = self.mobiface(
                    F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))

            src_id_sim += (torch.cosine_similarity(src_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()
            tgt_id_sim += (torch.cosine_similarity(tgt_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()

        self.generator.train()

        metrics = {
            'src_similarity': 100 * (src_id_sim / len(val_dataloader)).item(),
            'tgt_similarity': 100 * (tgt_id_sim / len(val_dataloader)).item()
        }
        return metrics

    def generate(self, Xs, Xt, same_person):
        """Write a source/target/swap comparison grid to results/<dir>."""
        def get_grid_image(X):
            X = X[:8]
            X = torchvision.utils.make_grid(X.detach().cpu(), nrow=X.shape[0])
            # Undo the [-1, 1] normalization for image output.
            X = (X * 0.5 + 0.5) * 255
            return X

        def make_image(Xs, Xt, Y_hat):
            Xs = get_grid_image(Xs)
            Xt = get_grid_image(Xt)
            Y_hat = get_grid_image(Y_hat)
            return torch.cat((Xs, Xt, Y_hat), dim=1).numpy()

        with torch.no_grad():
            embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            self.generator.eval()
            Y_hat = self.generator(Xt, embed, return_attributes=False)
            self.generator.train()

        image = make_image(Xs, Xt, Y_hat)
        if not os.path.exists(f'results/{self.model_dir}'):
            os.makedirs(f'results/{self.model_dir}')
        cv2.imwrite(f'results/{self.model_dir}/{self.iter}.jpg',
                    image.transpose([1, 2, 0]))

    def get_opt_stats(self, optimizer, type=''):
        """Return the optimizer's current LR for logging."""
        stats = {f'{type}_lr': optimizer.param_groups[0]['lr']}
        return stats

    def adjust_lr(self, optimizer):
        """Set the optimizer LR: linear warmup to self.lr over
        self.warmup steps, then cosine decay to 0 at self.max_iters.
        Returns the LR that was applied."""
        if self.iter <= self.warmup:
            lr = self.lr * self.iter / self.warmup
        else:
            lr = self.lr * (1 + cos(pi * (self.iter - self.warmup) /
                                    (self.max_iters - self.warmup))) / 2

        for group in optimizer.param_groups:
            group['lr'] = lr
        return lr

    def write_logs(self, losses=None, metrics=None, stats=None, type='loss'):
        """Push loss/metric/stat dicts to TensorBoard at the current step."""
        if losses:
            for name, value in losses.items():
                self.writer.add_scalar(f'{type}/{name}', value, self.iter)
        if metrics:
            for name, value in metrics.items():
                self.writer.add_scalar(f'metric/{name}', value, self.iter)
        if stats:
            for name, value in stats.items():
                self.writer.add_scalar(f'stats/{name}', value, self.iter)

    def save_generator(self, max_checkpoints=100):
        """Save generator weights; prune the OLDEST checkpoint when over
        max_checkpoints.

        Fixes: the glob must look in checkpoints/<model_dir> (where files
        are actually written — the old '{model_dir}/*.pt' never matched),
        and the victim must be chosen by age (glob order is arbitrary).
        """
        checkpoints = glob.glob(f'checkpoints/{self.model_dir}/generator_*.pt')
        if len(checkpoints) > max_checkpoints:
            os.remove(min(checkpoints, key=os.path.getctime))
        with open(f'checkpoints/{self.model_dir}/generator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.generator.state_dict(), f)

    def save_discriminator(self, max_checkpoints=100):
        """Save discriminator weights; prune the OLDEST checkpoint when
        over max_checkpoints (same path/ordering fix as save_generator)."""
        checkpoints = glob.glob(
            f'checkpoints/{self.model_dir}/discriminator_*.pt')
        if len(checkpoints) > max_checkpoints:
            os.remove(min(checkpoints, key=os.path.getctime))
        with open(f'checkpoints/{self.model_dir}/discriminator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.discriminator.state_dict(), f)

    def load_discriminator(self, path, load_last=True):
        """Load discriminator weights from *path* (or the newest
        discriminator*.pt inside it when load_last)."""
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/discriminator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except (ValueError):
                print(f'Directory is empty: {path}')

        try:
            self.discriminator.load_state_dict(torch.load(path))
            self.cuda()
        except (FileNotFoundError):
            print(f'No such file: {path}')

    def load_generator(self, path, load_last=True):
        """Load generator weights and restore the step counter parsed from
        the checkpoint FILENAME (generator_<iter>.pt)."""
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/generator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except (ValueError):
                print(f'Directory is empty: {path}')

        try:
            self.generator.load_state_dict(torch.load(path))
            # Fix: parse digits from the basename only — digits in the
            # directory path used to corrupt the restored iteration.
            iter_str = ''.join(filter(str.isdigit, os.path.basename(path)))
            self._iter = nn.Parameter(torch.tensor(int(iter_str)),
                                      requires_grad=False)
            self.cuda()
        except (FileNotFoundError):
            print(f'No such file: {path}')
# ---- Ejemplo n.º 24 (0) ---- (scraper separator; original snippet boundary)
def run(args):
    """Build, optionally restore, and train a neural vocoder with its
    discriminator.

    Supports four generators selected by ``args.model_name`` ("melgan",
    "hifigan", "multiband-hifigan", "basis-melgan"), reads hyper-parameters
    from the YAML file at ``args.config``, optionally runs apex AMP mixed
    precision, and validates every ``hp.valid_step`` steps.

    NOTE(review): relies on module-level names (``hp``, ``logger``,
    ``trainer``, ``Loss``, dataset/loader helpers) defined elsewhere in the
    file.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    logger.info(f"Loading Model of {args.model_name}...")
    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # prefer yaml.SafeLoader unless the config file is fully trusted.
        config = yaml.load(f, Loader=yaml.Loader)
    hp.lambda_stft = config["lamda_stft"]  # key is misspelled in the config schema; kept as-is
    hp.use_feature_map_loss = config["use_feature_map_loss"]

    if args.model_name == "melgan":
        model = MelGANGenerator(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            kernel_size=config["kernel_size"],
            channels=config["channels"],
            upsample_scales=config["upsample_scales"],
            stack_kernel_size=config["stack_kernel_size"],
            stacks=config["stacks"],
            use_weight_norm=config["use_weight_norm"],
            use_causal_conv=config["use_causal_conv"]).to(device)
    elif args.model_name == "hifigan":
        model = HiFiGANGenerator(
            resblock_kernel_sizes=config["resblock_kernel_sizes"],
            upsample_rates=config["upsample_rates"],
            upsample_initial_channel=config["upsample_initial_channel"],
            resblock_type=config["resblock_type"],
            upsample_kernel_sizes=config["upsample_kernel_sizes"],
            resblock_dilation_sizes=config["resblock_dilation_sizes"],
            transposedconv=config["transposedconv"],
            bias=config["bias"]).to(device)
    elif args.model_name == "multiband-hifigan":
        model = MultiBandHiFiGANGenerator(
            resblock_kernel_sizes=config["resblock_kernel_sizes"],
            upsample_rates=config["upsample_rates"],
            upsample_initial_channel=config["upsample_initial_channel"],
            resblock_type=config["resblock_type"],
            upsample_kernel_sizes=config["upsample_kernel_sizes"],
            resblock_dilation_sizes=config["resblock_dilation_sizes"],
            transposedconv=config["transposedconv"],
            bias=config["bias"]).to(device)
    elif args.model_name == "basis-melgan":
        # Basis-MelGAN needs a precomputed basis-signal weight matrix.
        basis_signal_weight = np.load(
            os.path.join("Basis-MelGAN-dataset", "basis_signal_weight.npy"))
        basis_signal_weight = torch.from_numpy(basis_signal_weight)
        model = BasisMelGANGenerator(
            basis_signal_weight=basis_signal_weight,
            L=config["L"],
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            kernel_size=config["kernel_size"],
            channels=config["channels"],
            upsample_scales=config["upsample_scales"],
            stack_kernel_size=config["stack_kernel_size"],
            stacks=config["stacks"],
            use_weight_norm=config["use_weight_norm"],
            use_causal_conv=config["use_causal_conv"],
            transposedconv=config["transposedconv"]).to(device)
    else:
        raise Exception("no model find!")
    pqmf = None
    if config["multiband"] == True:  # NOTE(review): prefer `if config["multiband"]:`
        logger.info("Define PQMF")
        pqmf = PQMF().to(device)
    logger.info(f"model is {str(model)}")
    discriminator = Discriminator().to(device)

    logger.info("Model Has Been Defined")
    num_param = get_param_num(model)
    logger.info(f'Number of TTS Parameters: {num_param}')

    # Optimizer and loss
    basis_signal_optimizer = None
    if not args.mixprecision:
        if args.model_name == "basis-melgan":
            optimizer = Adam(model.melgan.parameters(),
                             lr=args.learning_rate,
                             eps=1.0e-6,
                             weight_decay=0.0)
            # freeze basis signal layer
            basis_signal_optimizer = Adam(model.basis_signal.parameters())
        else:
            optimizer = Adam(model.parameters(),
                             lr=args.learning_rate,
                             eps=1.0e-6,
                             weight_decay=0.0)
        discriminator_optimizer = Adam(discriminator.parameters(),
                                       lr=args.learning_rate_discriminator,
                                       eps=1.0e-6,
                                       weight_decay=0.0)
    else:
        if args.model_name == "basis-melgan":
            raise Exception("basis melgan don't support amp!")
        optimizer = apex.optimizers.FusedAdam(model.parameters(),
                                              lr=args.learning_rate)
        discriminator_optimizer = apex.optimizers.FusedAdam(
            discriminator.parameters(), lr=args.learning_rate_discriminator)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          keep_batchnorm_fp32=None)
        discriminator, discriminator_optimizer = amp.initialize(
            discriminator, discriminator_optimizer, opt_level="O1")
        logger.info("Start mix precision training...")

    if args.use_scheduler:
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=2500,
                                      eta_min=args.learning_rate / 10.)
        discriminator_scheduler = CosineAnnealingLR(
            discriminator_optimizer,
            T_max=2500,
            eta_min=args.learning_rate_discriminator / 10.)
    else:
        scheduler = None
        discriminator_scheduler = None
    vocoder_loss = Loss().to(device)
    logger.info("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    os.makedirs(hp.checkpoint_path, exist_ok=True)
    # Timestamp-based run directory (":" and "." are invalid/awkward in paths).
    current_checkpoint_path = str(datetime.now()).replace(" ", "-").replace(
        ":", "-").replace(".", "-")
    current_checkpoint_path = os.path.join(hp.checkpoint_path,
                                           current_checkpoint_path)
    try:
        checkpoint = torch.load(os.path.join(args.checkpoint_path),
                                map_location=torch.device(device))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if 'discriminator' in checkpoint:
            logger.info("loading discriminator")
            discriminator.load_state_dict(checkpoint['discriminator'])
            discriminator_optimizer.load_state_dict(
                checkpoint['discriminator_optimizer'])
        os.makedirs(current_checkpoint_path, exist_ok=True)
        if args.mixprecision:
            amp.load_state_dict(checkpoint['amp'])
        logger.info("\n---Model Restored at Step %d---\n" % args.restore_step)
    except:  # NOTE(review): bare except hides real errors; consider narrowing to (FileNotFoundError, KeyError, RuntimeError)
        logger.info("\n---Start New Training---\n")
        os.makedirs(current_checkpoint_path, exist_ok=True)

    # Init logger
    os.makedirs(hp.logger_path, exist_ok=True)
    current_logger_path = str(datetime.now()).replace(" ", "-").replace(
        ":", "-").replace(".", "-")
    writer = SummaryWriter(
        os.path.join(hp.tensorboard_path, current_logger_path))
    current_logger_path = os.path.join(hp.logger_path, current_logger_path)
    os.makedirs(current_logger_path, exist_ok=True)

    # Get buffer
    if args.model_name != "basis-melgan":
        logger.info("Load data to buffer")
        buffer = load_data_to_buffer(args.audio_index_path,
                                     args.mel_index_path,
                                     logger,
                                     feature_savepath="features_train.bin")
        logger.info("Load valid data to buffer")
        valid_buffer = load_data_to_buffer(
            args.audio_index_valid_path,
            args.mel_index_valid_path,
            logger,
            feature_savepath="features_valid.bin")

    # Get dataset
    if args.model_name == "basis-melgan":
        dataset = WeightDataset(args.audio_index_path, args.mel_index_path,
                                config["L"])
        valid_dataset = WeightDataset(args.audio_index_valid_path,
                                      args.mel_index_valid_path, config["L"])
    else:
        dataset = BufferDataset(buffer)
        valid_dataset = BufferDataset(valid_buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=4,
                                 prefetch_factor=2,
                                 pin_memory=True)
    logger.info(f"Length of training loader is {len(training_loader)}")
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size

    # Define Some Information
    time_list = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):

            # real batch start here
            # Each loader item is a list of hp.batch_expand_size sub-batches.
            for j, db in enumerate(batchs):
                current_step = i * hp.batch_expand_size + j + args.restore_step + epoch * len(
                    training_loader) * hp.batch_expand_size + 1

                # Get Data
                clock_1_s = time.perf_counter()
                mel = db["mel"].float().to(device)
                wav = db["wav"].float().to(device)
                mel = mel.contiguous().transpose(1, 2)
                weight = None
                if "weight" in db:
                    weight = db["weight"].float().to(device)
                clock_1_e = time.perf_counter()
                time_used_1 = round(clock_1_e - clock_1_s, 5)

                # Training
                clock_2_s = time.perf_counter()
                time_list = trainer(
                    model,
                    discriminator,
                    optimizer,
                    discriminator_optimizer,
                    scheduler,
                    discriminator_scheduler,
                    vocoder_loss,
                    mel,
                    wav,
                    epoch,
                    current_step,
                    total_step,
                    time_list,
                    Start,
                    current_checkpoint_path,
                    current_logger_path,
                    writer,
                    weight=weight,
                    basis_signal_optimizer=basis_signal_optimizer,
                    pqmf=pqmf,
                    mixprecision=args.mixprecision)
                clock_2_e = time.perf_counter()
                time_used_2 = round(clock_2_e - clock_2_s, 5)

                if current_step % hp.valid_step == 0:
                    logger.info("Start valid...")
                    valid_loader = DataLoader(
                        valid_dataset,
                        batch_size=1,
                        shuffle=True,
                        collate_fn=collate_fn_tensor_valid,
                        num_workers=0)
                    valid_loss_all = 0.
                    for ii, valid_batch in enumerate(valid_loader):
                        valid_mel = valid_batch["mel"].float().to(device)
                        valid_mel = valid_mel.contiguous().transpose(1, 2)
                        valid_wav = valid_batch["wav"].float().to(device)
                        with torch.no_grad():
                            if args.model_name == "basis-melgan":
                                valid_est_source, _ = model(valid_mel)
                            else:
                                valid_est_source = model(valid_mel)
                            valid_stft_loss, _ = vocoder_loss(valid_est_source,
                                                              valid_wav,
                                                              pqmf=pqmf)
                            valid_loss_all += valid_stft_loss.item()
                        if ii == hp.valid_num:
                            break
                    writer.add_scalar('valid_stft_loss',
                                      valid_loss_all / float(hp.valid_num),
                                      global_step=current_step)

    # NOTE(review): export_scalars_to_json exists on tensorboardX's
    # SummaryWriter but not on torch.utils.tensorboard's -- confirm which
    # writer this file imports.
    writer.export_scalars_to_json(os.path.join("all_scalars.json"))
    writer.close()
    return
Ejemplo n.º 25
0
File: main.py Project: FengHZ/DCGAN
def main(args=args):
    """Train a DCGAN on the CelebA dataset.

    Builds the train/valid/test filename lists from the official eval
    partition file, optionally attaches the 40 binary attribute labels,
    constructs the generator/discriminator pair, resumes from a checkpoint
    when ``args.resume`` is set, and saves a checkpoint every epoch.

    NOTE: ``args=args`` binds the module-level ``args`` namespace at import
    time; kept for backward compatibility with existing callers.
    """
    dataset_base_path = path.join(args.base_path, "dataset", "celeba")
    image_base_path = path.join(dataset_base_path, "img_align_celeba")
    split_dataset_path = path.join(dataset_base_path, "Eval", "list_eval_partition.txt")
    with open(split_dataset_path, "r") as f:
        split_annotation = f.read().splitlines()
    # create the data name list for train,test and valid
    # partition codes in the eval file: 0 = train, 1 = valid, other = test
    train_data_name_list = []
    test_data_name_list = []
    valid_data_name_list = []
    for item in split_annotation:
        item = item.split(" ")
        if item[1] == '0':
            train_data_name_list.append(item[0])
        elif item[1] == '1':
            valid_data_name_list.append(item[0])
        else:
            test_data_name_list.append(item[0])
    attribute_annotation_dict = None
    if args.need_label:
        attribute_annotation_path = path.join(dataset_base_path, "Anno", "list_attr_celeba.txt")
        with open(attribute_annotation_path, "r") as f:
            attribute_annotation = f.read().splitlines()
        attribute_annotation = attribute_annotation[2:]  # skip count line + header row
        attribute_annotation_dict = {}
        for item in attribute_annotation:
            img_name, attribute = item.split(" ", 1)
            # Attribute values are the string literals "1"/"-1"; parse with
            # int() -- the original used eval(), which executes arbitrary
            # expressions from the annotation file (a code-injection hazard).
            attribute = tuple(int(attr) for attr in attribute.split(" ") if attr != "")
            assert len(attribute) == 40, "the attribute of item {} is not equal to 40".format(img_name)
            attribute_annotation_dict[img_name] = attribute
    discriminator = Discriminator(num_channel=args.num_channel, num_feature=args.dnf,
                                  data_parallel=args.data_parallel).cuda()
    generator = Generator(latent_dim=args.latent_dim, num_feature=args.gnf,
                          num_channel=args.num_channel, data_parallel=args.data_parallel).cuda()
    # Interactive confirmation before starting a (long) training run.
    input("Begin the {} time's training, the train dataset has {} images and the valid has {} images".format(
        args.train_time, len(train_data_name_list), len(valid_data_name_list)))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    d_scheduler = ExponentialLR(d_optimizer, gamma=args.decay_lr)
    g_scheduler = ExponentialLR(g_optimizer, gamma=args.decay_lr)
    writer_log_dir = "{}/DCGAN/runs/train_time:{}".format(args.base_path, args.train_time)
    # Here we implement the resume part
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if not args.not_resume_arg:
                # Restore the full argument namespace from the checkpoint
                # (overwrites CLI args) and continue from the saved epoch.
                args = checkpoint['args']
                args.start_epoch = checkpoint['epoch']
            discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
            generator.load_state_dict(checkpoint["generator_state_dict"])
            d_optimizer.load_state_dict(checkpoint['discriminator_optimizer'])
            g_optimizer.load_state_dict(checkpoint['generator_optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError("Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        # Fresh run: ask before wiping an existing tensorboard log directory.
        if os.path.exists(writer_log_dir):
            flag = input("DCGAN train_time:{} will be removed, input yes to continue:".format(
                args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    # Here we just use the train dset in training
    train_dset = CelebADataset(base_path=image_base_path, data_name_list=train_data_name_list,
                               image_size=args.image_size,
                               label_dict=attribute_annotation_dict)
    train_dloader = DataLoader(dataset=train_dset, batch_size=args.batch_size, shuffle=True,
                               num_workers=args.workers, pin_memory=True)
    criterion = nn.BCELoss()
    for epoch in range(args.start_epoch, args.epochs):
        train(train_dloader, generator, discriminator, g_optimizer, d_optimizer, criterion, writer, epoch)
        # adjust lr
        d_scheduler.step()
        g_scheduler.step()
        # save parameters
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "discriminator_state_dict": discriminator.state_dict(),
            "generator_state_dict": generator.state_dict(),
            'discriminator_optimizer': d_optimizer.state_dict(),
            'generator_optimizer': g_optimizer.state_dict()
        })
Ejemplo n.º 26
0
def main():
    """GAIL-style training loop: alternate discriminator updates (expert vs.
    generated transitions) with PPO updates of the policy/value networks.

    NOTE(review): relies on module-level names (``args``, ``device``,
    ``discrete_action_sections``, ``discrete_state_sections``, ``ppo_step``,
    ``ExpertDataSet``) defined elsewhere in the file.
    """
    # define actor/critic/discriminator net and optimizer
    policy = Policy(discrete_action_sections, discrete_state_sections)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    writer = SummaryWriter()

    # load expert data
    dataset = ExpertDataSet(args.expert_activities_data_path,
                            args.expert_cost_data_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=False,
                                  num_workers=1)

    # load models
    # discriminator.load_state_dict(torch.load('./model_pkl/Discriminator_model_3.pkl'))
    # policy.transition_net.load_state_dict(torch.load('./model_pkl/Transition_model_3.pkl'))
    # policy.policy_net.load_state_dict(torch.load('./model_pkl/Policy_model_3.pkl'))
    # value.load_state_dict(torch.load('./model_pkl/Value_model_3.pkl'))

    print('#############  start training  ##############')

    # update discriminator
    num = 0  # global PPO optimization step counter (for tensorboard x-axis)
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        start_time = time.time()
        memory = policy.collect_samples(args.ppo_buffer_size, size=10000)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        # Each batch field is a sequence of (1, dim) tensors; stack and
        # squeeze to (N, dim), detached from the sampling graph.
        continuous_state = torch.stack(
            batch.continuous_state).squeeze(1).detach()
        discrete_action = torch.stack(
            batch.discrete_action).squeeze(1).detach()
        continuous_action = torch.stack(
            batch.continuous_action).squeeze(1).detach()
        next_discrete_state = torch.stack(
            batch.next_discrete_state).squeeze(1).detach()
        next_continuous_state = torch.stack(
            batch.next_continuous_state).squeeze(1).detach()
        old_log_prob = torch.stack(batch.old_log_prob).detach()
        mask = torch.stack(batch.mask).squeeze(1).detach()
        discrete_state = torch.stack(batch.discrete_state).squeeze(1).detach()
        # Pre-allocate empty tensors so the logging below is safe even if the
        # inner loops never run.
        d_loss = torch.empty(0, device=device)
        p_loss = torch.empty(0, device=device)
        v_loss = torch.empty(0, device=device)
        gen_r = torch.empty(0, device=device)
        expert_r = torch.empty(0, device=device)
        # Single discriminator epoch per iteration; range(1) kept as a knob
        # for running more discriminator passes.
        for _ in range(1):
            for expert_state_batch, expert_action_batch in data_loader:
                gen_state = torch.cat((discrete_state, continuous_state),
                                      dim=-1)
                gen_action = torch.cat((discrete_action, continuous_action),
                                       dim=-1)
                gen_r = discriminator(gen_state, gen_action)
                expert_r = discriminator(expert_state_batch,
                                         expert_action_batch)
                optimizer_discriminator.zero_grad()
                # Standard GAN discriminator loss: generated -> 0, expert -> 1.
                d_loss = discriminator_criterion(gen_r,
                                                 torch.zeros(gen_r.shape, device=device)) + \
                         discriminator_criterion(expert_r,
                                                 torch.ones(expert_r.shape, device=device))
                # NOTE(review): variance-penalized variant computed but its
                # backward() is intentionally commented out below.
                total_d_loss = d_loss - 10 * torch.var(gen_r.to(device))
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
        writer.add_scalar('d_loss', d_loss, ep)
        # writer.add_scalar('total_d_loss', total_d_loss, ep)
        writer.add_scalar('expert_r', expert_r.mean(), ep)

        # update PPO
        # Re-score the whole buffer with the freshly updated discriminator;
        # these scores act as rewards for PPO.
        gen_r = discriminator(
            torch.cat((discrete_state, continuous_state), dim=-1),
            torch.cat((discrete_action, continuous_action), dim=-1))
        optimize_iter_num = int(
            math.ceil(discrete_state.shape[0] / args.ppo_mini_batch_size))
        for ppo_ep in range(args.ppo_optim_epoch):
            for i in range(optimize_iter_num):
                num += 1
                index = slice(
                    i * args.ppo_mini_batch_size,
                    min((i + 1) * args.ppo_mini_batch_size,
                        discrete_state.shape[0]))
                discrete_state_batch, continuous_state_batch, discrete_action_batch, continuous_action_batch, \
                old_log_prob_batch, mask_batch, next_discrete_state_batch, next_continuous_state_batch, gen_r_batch = \
                    discrete_state[index], continuous_state[index], discrete_action[index], continuous_action[index], \
                    old_log_prob[index], mask[index], next_discrete_state[index], next_continuous_state[index], gen_r[
                        index]
                v_loss, p_loss = ppo_step(
                    policy, value, optimizer_policy, optimizer_value,
                    discrete_state_batch, continuous_state_batch,
                    discrete_action_batch, continuous_action_batch,
                    next_discrete_state_batch, next_continuous_state_batch,
                    gen_r_batch, old_log_prob_batch, mask_batch,
                    args.ppo_clip_epsilon)
            writer.add_scalar('p_loss', p_loss, num)
            writer.add_scalar('v_loss', v_loss, num)
            writer.add_scalar('gen_r', gen_r.mean(), num)

        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('d_loss', d_loss.item())
        # print('p_loss', p_loss.item())
        # print('v_loss', v_loss.item())
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())

        memory.clear_memory()
        # save models
        torch.save(discriminator.state_dict(),
                   './model_pkl/Discriminator_model_4.pkl')
        torch.save(policy.transition_net.state_dict(),
                   './model_pkl/Transition_model_4.pkl')
        torch.save(policy.policy_net.state_dict(),
                   './model_pkl/Policy_model_4.pkl')
        torch.save(value.state_dict(), './model_pkl/Value_model_4.pkl')
Ejemplo n.º 27
0
class HiDDen(object):
    """HiDDeN watermarking model: an encoder/decoder that hides a message in
    an image, trained adversarially against a discriminator that tries to
    tell cover images from encoded ones.
    """

    def __init__(self, config: HiDDenConfiguration, device: torch.device):
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # Discriminator targets: 1 for unmodified (cover) images, 0 for
        # message-carrying (encoded) images.
        self.cover_label = 1
        self.encod_label = 0

    def train_on_batch(self, batch: list):
        '''
        Trains the network on a single batch consisting of images and messages.

        Returns (losses_dict, (encoded_images, decoded_messages)).
        '''
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.train()
        self.discr.train()

        with torch.enable_grad():
            # ---------- Train the discriminator----------
            self.opt_discr.zero_grad()

            # train on cover
            # dtype=torch.float is required: torch.full infers an integer
            # dtype from the int fill value (PyTorch >= 1.5), and
            # BCEWithLogitsLoss needs floating-point targets.
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                dtype=torch.float,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, decoded_messages = self.enc_dec(images, messages)
            # detach() so discriminator gradients don't flow into the encoder.
            d_on_encoded = self.discr(encoded_images.detach())
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encod.backward()
            self.opt_discr.step()

            #---------- Train the generator----------
            self.opt_enc_dec.zero_grad()

            d_on_encoded_for_enc = self.discr(encoded_images)
            # Encoder tries to make encoded images look like covers.
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.opt_enc_dec.step()

        # Bitwise error: fraction of message bits decoded incorrectly after
        # rounding the decoder output to {0, 1}.
        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) \
                      / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        '''Run validation on a batch consisting of [images, messages].

        Same losses as train_on_batch, but without gradient tracking or
        optimizer steps.
        '''
        images, messages = batch
        batch_size = images.shape[0]

        self.enc_dec.eval()
        self.discr.eval()

        with torch.no_grad():
            # dtype=torch.float: see train_on_batch -- BCEWithLogitsLoss
            # requires floating-point targets.
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                dtype=torch.float,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images)
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy()))\
                     / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-err': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_enced_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def to_stirng(self):
        # Name kept (typo and all) for backward compatibility with callers.
        return f'{str(self.enc_dec)}\n{str(self.discr)}'
Ejemplo n.º 28
0
def main():
    """Adversarial semantic-segmentation training: a segmentor (generator)
    trained with cross-entropy plus an adversarial signal from a
    discriminator, with polynomial LR decay and periodic checkpointing.

    NOTE(review): relies on module-level names (``settings``, ``Segmentor``,
    ``TrainDataset``, ``TestDataset``, ``train_one_epoch``,
    ``save_checkpoint``, loss classes) defined elsewhere in the file.
    """

    # set torch and numpy seed for reproducibility
    torch.manual_seed(27)
    np.random.seed(27)

    # tensorboard writer
    writer = SummaryWriter(settings.TENSORBOARD_DIR)
    # makedir snapshot
    makedir(settings.CHECKPOINT_DIR)

    # enable cudnn
    torch.backends.cudnn.enabled = True

    # create segmentor network
    model_G = Segmentor(pretrained=settings.PRETRAINED,
                        num_classes=settings.NUM_CLASSES,
                        modality=settings.MODALITY)

    model_G.train()
    model_G.cuda()

    # benchmark mode picks the fastest conv algorithms for fixed input sizes
    torch.backends.cudnn.benchmark = True

    # create discriminator network
    model_D = Discriminator(settings.NUM_CLASSES)
    model_D.train()
    model_D.cuda()

    # dataset and dataloader
    dataset = TrainDataset()
    dataloader = data.DataLoader(dataset,
                                 batch_size=settings.BATCH_SIZE,
                                 shuffle=True,
                                 num_workers=settings.NUM_WORKERS,
                                 pin_memory=True,
                                 drop_last=True)

    test_dataset = TestDataset(data_root=settings.DATA_ROOT_VAL,
                               data_list=settings.DATA_LIST_VAL)
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=settings.NUM_WORKERS,
                                      pin_memory=True)

    # optimizer for generator network (segmentor)
    optim_G = optim.SGD(model_G.optim_parameters(settings.LR),
                        lr=settings.LR,
                        momentum=settings.LR_MOMENTUM,
                        weight_decay=settings.WEIGHT_DECAY)

    # lr scheduler for optimi_G
    # polynomial decay: lr * (1 - epoch/EPOCHS) ** LR_POLY_POWER
    lr_lambda_G = lambda epoch: (1 - epoch / settings.EPOCHS
                                 )**settings.LR_POLY_POWER
    lr_scheduler_G = optim.lr_scheduler.LambdaLR(optim_G,
                                                 lr_lambda=lr_lambda_G)

    # optimizer for discriminator network
    optim_D = optim.Adam(model_D.parameters(), settings.LR_D)

    # lr scheduler for optimi_D
    lr_lambda_D = lambda epoch: (1 - epoch / settings.EPOCHS
                                 )**settings.LR_POLY_POWER
    lr_scheduler_D = optim.lr_scheduler.LambdaLR(optim_D,
                                                 lr_lambda=lr_lambda_D)

    # losses
    ce_loss = CrossEntropyLoss2d(
        ignore_index=settings.IGNORE_LABEL)  # to use for segmentor
    bce_loss = BCEWithLogitsLoss2d()  # to use for discriminator

    # upsampling for the network output
    upsample = nn.Upsample(size=(settings.CROP_SIZE, settings.CROP_SIZE),
                           mode='bilinear',
                           align_corners=True)

    # # labels for adversarial training
    # pred_label = 0
    # gt_label = 1

    # load the model to resume training
    last_epoch = -1
    if settings.RESUME_TRAIN:
        checkpoint = torch.load(settings.LAST_CHECKPOINT)

        model_G.load_state_dict(checkpoint['model_G_state_dict'])
        model_G.train()
        model_G.cuda()

        model_D.load_state_dict(checkpoint['model_D_state_dict'])
        model_D.train()
        model_D.cuda()

        optim_G.load_state_dict(checkpoint['optim_G_state_dict'])
        optim_D.load_state_dict(checkpoint['optim_D_state_dict'])

        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G_state_dict'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D_state_dict'])

        last_epoch = checkpoint['epoch']

        # purge the logs after the last_epoch
        writer = SummaryWriter(settings.TENSORBOARD_DIR,
                               purge_step=(last_epoch + 1) * len(dataloader))

    for epoch in range(last_epoch + 1, settings.EPOCHS + 1):

        train_one_epoch(model_G,
                        model_D,
                        optim_G,
                        optim_D,
                        dataloader,
                        test_dataloader,
                        epoch,
                        upsample,
                        ce_loss,
                        bce_loss,
                        writer,
                        print_freq=5,
                        eval_freq=settings.EVAL_FREQ)

        if epoch % settings.CHECKPOINT_FREQ == 0 and epoch != 0:
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)

        # save the final model
        if epoch >= settings.EPOCHS:
            print('saving the final model')
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)
            writer.close()

        # step schedulers once per epoch (after the optimizer steps inside
        # train_one_epoch, per the PyTorch >= 1.1 ordering)
        lr_scheduler_G.step()
        lr_scheduler_D.step()
Ejemplo n.º 29
0
def main():
    """Adversarially train a DeepLab segmentation network (CPU variant).

    A fully convolutional discriminator is trained to tell predicted
    label maps from ground-truth ones, while the segmentation network is
    trained with a spatial cross-entropy term plus an adversarial term
    that tries to fool the discriminator (following Hung et al.,
    "Adversarial Learning for Semi-Supervised Semantic Segmentation").

    Depends on module-level globals: ``args`` (parsed CLI options),
    ``start`` (timer value set before calling) and the helpers
    ``loss_calc``, ``make_D_label``, ``one_hot``,
    ``adjust_learning_rate`` and ``adjust_learning_rate_D``.
    """

    # parse input size given as an "H,W" string
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # create segmentation network; pretrained-weight restoration and CUDA
    # are disabled in this CPU-only variant
    model = DeepLab(num_classes=args.num_classes)
    model.train()
    model.cpu()

    # create discriminator network
    model_D = Discriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cpu()

    # MILESTONE 1
    print("Printing MODELS ...")
    print(model)
    print(model_D)

    # Create directory to save snapshots of the model
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Two independent loaders: one feeds images/labels to the generator
    # step, the other supplies ground-truth maps for the discriminator's
    # "real" branch.
    train_dataset = MyCustomDataset()
    train_gt_dataset = MyCustomDataset()

    trainloader = data.DataLoader(train_dataset, batch_size=5, shuffle=True)
    trainloader_gt = data.DataLoader(train_gt_dataset,
                                     batch_size=5,
                                     shuffle=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # MILESTONE 2
    print("Printing Loaders")
    print(trainloader_iter)
    print(trainloader_gt_iter)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # MILESTONE 3
    print("Printing OPTIMIZERS ...")
    print(optimizer)
    print(optimizer_D)

    # loss / bilinear upsampling back to the input resolution
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): size is (input_size[1], input_size[0]) == (w, h) while
    # nn.Upsample expects (H, W) — confirm this ordering is intended.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        # kept at 0 for the progress line; the semi-supervised branch that
        # updated these has been removed from this variant
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D while updating the generator
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source (labelled) data; restart the loader
            # iterator when it is exhausted
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cpu()
            # pixels labelled 255 are "ignore" and excluded from the
            # adversarial target map built by make_D_label
            ignore_mask = (labels.numpy() == 255)

            # segmentation prediction
            pred = interp(model(images))
            # (spatial multi-class) cross entropy loss
            loss_seg = loss_calc(pred, labels)

            # discriminator prediction on the class-probability maps
            # (dim=1 is the class axis of the NCHW logits)
            D_out = interp(model_D(F.softmax(pred, dim=1)))
            # adversarial loss: push D's output towards the "real" label
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            # multi-task loss; lambda_adv_pred weights the adversarial term
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # normalize so gradients are averaged over the iter_size sub-steps
            loss = loss / args.iter_size

            # back propagation
            loss.backward()

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with the detached prediction as the "fake" sample
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with one-hot ground-truth maps as the "real" sample
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cpu()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        # final save on the last step, then stop
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        # periodic snapshot (skip step 0)
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    # `start` is expected to be set at module level before main() runs
    end = timeit.default_timer()
    print(end - start, 'seconds')