    def __init__(self, opt, resume_epoch=0):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 1:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model.to(opt.gpu_ids[0])
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        if opt.use_ema:
            self.netG_ema = EMA(opt.ema_beta)
            for name, param in self.pix2pix_model_on_one_gpu.net[
                    'netG'].named_parameters():
                if param.requires_grad:
                    self.netG_ema.register(name, param.data)
            self.netCorr_ema = EMA(opt.ema_beta)
            for name, param in self.pix2pix_model_on_one_gpu.net[
                    'netCorr'].named_parameters():
                if param.requires_grad:
                    self.netCorr_ema.register(name, param.data)

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
            if opt.continue_train and opt.which_epoch == 'latest':
                checkpoint = torch.load(
                    os.path.join(opt.checkpoints_dir, opt.name,
                                 'optimizer.pth'))
                self.optimizer_G.load_state_dict(checkpoint['G'])
                self.optimizer_D.load_state_dict(checkpoint['D'])
        self.last_data, self.last_netCorr, self.last_netG, self.last_optimizer_G = None, None, None, None
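The EMA helper used here is not shown in this example; a minimal sketch of the interface the code above relies on (register, call, assign, resume), assuming a standard exponential moving average over parameter tensors, might look like this:

class EMA:
    # Minimal exponential-moving-average tracker (a sketch, not the
    # project's actual implementation).
    def __init__(self, beta):
        self.beta = beta      # decay rate, e.g. 0.999
        self.shadow = {}      # name -> averaged tensor
        self.backup = {}      # name -> live tensor, saved during assign()

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def __call__(self, model):
        # called after each optimizer step to update the averages
        for name, param in model.named_parameters():
            if param.requires_grad:
                avg = self.beta * self.shadow[name] + (1.0 - self.beta) * param.data
                self.shadow[name] = avg.clone()

    def assign(self, model):
        # swap the averaged weights in (e.g. just before saving)
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def resume(self, model):
        # restore the live training weights after assign()
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}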
Example #2
    def initialize_networks(self, opt):
        self.netG = networks.define_G(opt)
        self.netD = networks.define_D(opt)
        # set require gradients
        if self.isTrain:
            self.set_requires_grad([self.netG, self.netD], True)
        else:
            self.set_requires_grad([self.netG, self.netD], False)
        if self.use_gpu:
            self.netG = DataParallelWithCallback(self.netG,
                                                 device_ids=opt['gpu_ids'])
            self.netD = DataParallelWithCallback(self.netD,
                                                 device_ids=opt['gpu_ids'])
        self.train_nets = [self.netG, self.netD]
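set_requires_grad is not defined in this snippet; a plausible minimal helper (an assumption based on the common pix2pix BaseModel convention) simply toggles gradient tracking:

    def set_requires_grad(self, nets, requires_grad=False):
        # accept a single network or a list of networks
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad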
Example #3
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            if not opt.unpairTrain:
                (
                    self.optimizer_G,
                    self.optimizer_D,
                ) = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            else:
                (
                    self.optimizer_G,
                    self.optimizer_D,
                    self.optimizer_D2,
                ) = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

        self.d_losses = {}
        self.nanCount = 0
Example #4
    def initialize_networks(self, opt):
        self.netGA = networks.define_G(opt, opt['netGA'])
        self.netGB = networks.define_G(opt, opt['netGB'])
        self.netDA = networks.define_D(opt, opt['netDA'])
        self.netDB = networks.define_D(opt, opt['netDB'])
        self.netEA, self.netHairA = networks.define_RES(
            opt, opt['input_nc_A'], opt['netEDA'])
        self.netEB, self.netHairB = networks.define_RES(
            opt, opt['input_nc_B'], opt['netEDB'])

        if self.opt['pretrain']:
            self.train_nets = [
                self.netGA, self.netGB, self.netDA, self.netDB, self.netEA,
                self.netHairA, self.netEB, self.netHairB
            ]
        else:
            self.train_nets = [self.netEA, self.netHairA]

        # set require gradients
        if self.isTrain:
            self.set_requires_grad(self.train_nets, True)
        else:
            self.set_requires_grad(self.train_nets, False)

        if self.use_gpu:
            for i in range(len(self.train_nets)):
                self.train_nets[i] = DataParallelWithCallback(
                    self.train_nets[i], device_ids=opt['gpu_ids'])
            if self.opt['pretrain']:
                self.netGA, self.netGB, self.netDA, self.netDB, self.netEA, \
                    self.netHairA, self.netEB, self.netHairB = self.train_nets
            else:
                self.netEA, self.netHairA = self.train_nets
Example #5
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = create_model(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model,
                device_ids=opt.gpu_ids,
                output_device=opt.gpu_ids[-1],
                chunk_size=opt.chunk_size)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model
        # self.Render = networks.Render(opt, render_size=opt.crop_size)
        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
Example #6
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)

        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model.cuda()
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None

        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
Example #7
    def __init__(self, opt):
        self.opt = opt
        if self.opt.model == 'pix2pix':
            self.pix2pix_model = Pix2pixModel(opt)
        elif self.opt.model == 'smis':
            self.pix2pix_model = SmisModel(opt)
        print(self.pix2pix_model)
        with open(os.path.join(opt.checkpoints_dir, opt.name, 'model.txt'),
                  'w') as f:
            f.write(self.pix2pix_model.__str__())
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
Example #8
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        #self.pix2pix_model = torch.nn.parallel.DistributedDataParallel(self.pix2pix_model,device_ids=[opt.gpu], find_unused_parameters=True)
        self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model,
                                                      device_ids=opt.gpu_ids)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = self.pix2pix_model_on_one_gpu.create_optimizers(
                opt)
            self.old_lr = opt.lr
Example #9
    def __init__(self, opt, model):
        super(MyModel, self).__init__()
        self.opt = opt
        model = model.cuda(opt.gpu_ids[0])
        self.module = model

        self.model = DataParallelWithCallback(model, device_ids=opt.gpu_ids)
        if opt.batch_for_first_gpu != -1:
            self.bs_per_gpu = (opt.batchSize - opt.batch_for_first_gpu) // (
                len(opt.gpu_ids) - 1)  # batch size for each GPU
        else:
            self.bs_per_gpu = int(
                np.ceil(float(opt.batchSize) /
                        len(opt.gpu_ids)))  # batch size for each GPU
        self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batchSize
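To make the padding arithmetic concrete: with batchSize=10, gpu_ids=[0, 1, 2] and batch_for_first_gpu left at -1, bs_per_gpu = ceil(10 / 3) = 4 and pad_bs = 4 * 3 - 10 = 2, i.e. two dummy samples must be appended so that every GPU receives an equally sized chunk.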
Example #10
    def __init__(self, opt):
        self.opt = opt
        self.seg_inpaint_model = SegInpaintModel(opt)
        if len(opt.gpu_ids) > 0:
            self.seg_inpaint_model = DataParallelWithCallback(
                self.seg_inpaint_model, device_ids=opt.gpu_ids)
            self.seg_inpaint_model_on_one_gpu = self.seg_inpaint_model.module
        else:
            self.seg_inpaint_model_on_one_gpu = self.seg_inpaint_model

        self.generated = None

        self.optimizer_SPNet, self.optimizer_SGNet, self.optimizer_D_seg, self.optimizer_D_img = \
            self.seg_inpaint_model_on_one_gpu.create_optimizers(opt)

        self.old_lr = opt.lr
Example #11
    def __init__(self, opt, model):
        super(MyModel, self).__init__()
        self.opt = opt
        model = model.cuda(opt.gpu_ids[0])
        self.module = model

        if opt.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
        else:
            # self.model = nn.DataParallel(model, device_ids=opt.gpu_ids)
            self.model = DataParallelWithCallback(model, device_ids=opt.gpu_ids)
        if opt.batch_for_first_gpu != -1:
            # batch size for each GPU
            self.bs_per_gpu = (opt.batchSize - opt.batch_for_first_gpu) // (
                len(opt.gpu_ids) - 1)
        else:
            # batch size for each GPU
            self.bs_per_gpu = int(np.ceil(float(opt.batchSize) / len(opt.gpu_ids)))
        self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batchSize
Example #12
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model,
                                                          device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

        self.amp = bool(AMP and opt.use_amp)
        
        if self.amp:
            self.scaler_G = GradScaler()
            self.scaler_D = GradScaler()
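The two GradScaler instances suggest separate loss scaling for the generator and discriminator steps. A hedged sketch of what the generator step could look like under AMP, assuming the usual SPADE-style 'generator' mode returning (losses, generated); this method is an illustration, not part of the original example:

    def run_generator_one_step_amp(self, data):
        self.optimizer_G.zero_grad()
        # run the forward pass in mixed precision when AMP is enabled
        with torch.cuda.amp.autocast(enabled=self.amp):
            g_losses, generated = self.pix2pix_model(data, mode='generator')
            g_loss = sum(g_losses.values()).mean()
        # scale the loss to avoid fp16 gradient underflow, then unscale + step
        self.scaler_G.scale(g_loss).backward()
        self.scaler_G.step(self.optimizer_G)
        self.scaler_G.update()
        self.g_losses, self.generated = g_losses, generated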
Example #13
    def __init__(self, opt):
        self.opt = opt
        if self.opt.dual:
            from models.pix2pix_dualmodel import Pix2PixModel
        elif self.opt.dual_segspade:
            from models.pix2pix_dual_segspademodel import Pix2PixModel
        elif opt.box_unpair:
            from models.pix2pix_dualunpair import Pix2PixModel
        else:
            from models.pix2pix_model import Pix2PixModel

        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
Example #14
class DxdyModel(BaseModel):
    def name(self):
        return 'DxdyModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        pass

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # set networks
        self.initialize_networks(opt)

        # set loss functions
        self.initialize_loss(opt)

        # set optimizer
        self.initialize_optimizer(opt)

        self.initialize_other(opt)

        self.model_dict = {
            'netG': {
                'model': self.netG.module if self.use_gpu else self.netG,
                'optimizer': self.optimizer_G
            },
            'netD': {
                'model': self.netD.module if self.use_gpu else self.netD,
                'optimizer': self.optimizer_D
            }
        }
        self.opt = opt

    def initialize_networks(self, opt):
        self.netG = networks.define_G(opt)
        self.netD = networks.define_D(opt)
        # set require gradients
        if self.isTrain:
            self.set_requires_grad([self.netG, self.netD], True)
        else:
            self.set_requires_grad([self.netG, self.netD], False)
        if self.use_gpu:
            self.netG = DataParallelWithCallback(self.netG,
                                                 device_ids=opt['gpu_ids'])
            self.netD = DataParallelWithCallback(self.netD,
                                                 device_ids=opt['gpu_ids'])
        self.train_nets = [self.netG, self.netD]

    def initialize_optimizer(self, opt):
        G_params = list(self.netG.parameters())
        D_params = list(self.netD.parameters())
        beta1, beta2 = opt['beta1'], opt['beta2']
        G_lr, D_lr = opt.get('lr_G', opt['lr']), opt.get('lr_D', opt['lr'])
        self.optimizer_G = torch.optim.Adam(G_params,
                                            lr=G_lr,
                                            betas=(beta1, beta2))
        self.optimizer_D = torch.optim.Adam(D_params,
                                            lr=D_lr,
                                            betas=(beta1, beta2))
        self.old_lr = opt['lr']

    def initialize_loss(self, opt):
        self.criterionGAN = networks.GANLoss(opt['gan_mode'],
                                             tensor=self.FloatTensor,
                                             opt=opt)
        # if self.use_gpu:
        #     self.criterionGAN = DataParallelWithCallback(
        #         self.criterionGAN, device_ids=opt['gpu_ids'])
        self.criterionReg = torch.nn.L1Loss()

    def initialize_other(self, opt):
        full_body_mesh_vert_pos, full_body_mesh_face_inds = tools.load_body_mesh(
        )
        self.full_body_mesh_vert_pos = full_body_mesh_vert_pos.unsqueeze(0)
        self.full_body_mesh_face_inds = full_body_mesh_face_inds.unsqueeze(0)
        sample_dataset = haya_data.Hair3D10KConvDataOnly()
        self.sample_loader = torch.utils.data.DataLoader(
            sample_dataset,
            batch_size=opt['batch_size'],
            shuffle=False,
            num_workers=opt['workers'],
            drop_last=True)
        self.sample_iter = iter(self.sample_loader)
        assert len(sample_dataset) > 0
        print(f'{len(sample_dataset)} samples loaded')

    def set_input(self, data):
        self.image = data['image'].to(self.device)
        self.mask = data['mask'].to(self.device)
        self.intensity = data['intensity'].to(self.device)
        self.gt_dxdy = data['dxdy'].to(torch.float).to(self.device)
        try:
            sample_data = next(self.sample_iter)
        except StopIteration:
            self.sample_iter = iter(self.sample_loader)
            sample_data = next(self.sample_iter)
        convdata = sample_data['convdata'].to(self.device)
        strands = convdata.permute(
            0, 2, 3, 4,
            1)[:, :3, :, :, :].contiguous()  # b x 3 x 32 x 32 x 300
        body_mesh_vert_pos = self.full_body_mesh_vert_pos.expand(
            strands.size(0), -1, -1).to(strands.device)
        body_mesh_face_inds = self.full_body_mesh_face_inds.expand(
            strands.size(0), -1, -1).to(strands.device)
        # generate random mvps
        mvps, _, _ = tools.generate_random_mvps(strands.size(0),
                                                strands.device)

        # render the 2D information
        self.strand_dxdy, self.strand_mask, body_mask, _, strand_vis, mvps, _ = tools.render(
            mvps,
            strands,
            body_mesh_vert_pos,
            body_mesh_face_inds,
            self.opt['im_size'],
            self.opt['expansion'],
            align_face=self.opt['align_face'],
            target_face_scale=self.opt['target_face_scale'])

    def forward(self):
        mask_ = self.mask.unsqueeze(1).type(self.image.dtype)
        strand_mask_ = self.strand_mask.unsqueeze(1).type(
            self.strand_dxdy.dtype)

        # for G
        self.pred_dxdy = self.netG(torch.cat([self.image, mask_], dim=1))
        if self.pred_dxdy.size(-1) != mask_.size(-1):
            mask_ = torch.nn.functional.interpolate(
                mask_, size=self.pred_dxdy.shape[-2:], mode='nearest')
        fake_sample = self.pred_dxdy * mask_.type(self.pred_dxdy.dtype)
        self.g_fake_score = self.netD(fake_sample)
        # for D
        fake_sample = self.pred_dxdy.detach() * mask_
        fake_sample.requires_grad_()
        real_sample = self.strand_dxdy * strand_mask_

        self.d_real_score, self.d_fake_score = self.netD(
            real_sample), self.netD(fake_sample)
        # for vis
        self.mask_ = mask_

        # scale the size of everything
        if self.pred_dxdy.size(-1) != self.mask.size(-1):
            self.mask = torch.nn.functional.interpolate(
                self.mask.unsqueeze(1),
                size=self.pred_dxdy.shape[-2:],
                mode='nearest').squeeze()
            self.intensity = torch.nn.functional.interpolate(
                self.intensity.unsqueeze(1),
                size=self.pred_dxdy.shape[-2:],
                mode='nearest').squeeze()
            self.gt_dxdy = torch.nn.functional.interpolate(
                self.gt_dxdy, size=self.pred_dxdy.shape[-2:], mode='nearest')

    def update_visuals(self):
        masked_pred_dxdy = torch.where(self.mask_ > 0., self.pred_dxdy,
                                       -torch.ones_like(self.pred_dxdy))
        masked_gt_dxdy = torch.where(self.mask_ > 0., self.gt_dxdy,
                                     -torch.ones_like(self.gt_dxdy))

        self.vis_dict['image'] = data_utils.make_grid_n(self.image[:6])
        self.vis_dict['gt_dxdy'] = data_utils.vis_orient(self.gt_dxdy[:6])
        self.vis_dict['pred_dxdy'] = data_utils.vis_orient(self.pred_dxdy[:6])
        self.vis_dict['masked_pred_dxdy'] = data_utils.vis_orient(
            masked_pred_dxdy[:6])
        self.vis_dict['masked_gt_dxdy'] = data_utils.vis_orient(
            masked_gt_dxdy[:6])
        self.vis_dict['render_dxdy'] = data_utils.vis_orient(
            self.strand_dxdy[:6])

    def backward_G(self):
        reg_loss = self.dxdy_reg_loss(self.pred_dxdy,
                                      self.gt_dxdy) * self.intensity
        reg_loss = (self.mask.expand_as(reg_loss).float() *
                    reg_loss).mean() * self.opt.get('lambda_reg', 1.)
        g_loss = self.criterionGAN(self.g_fake_score,
                                   True,
                                   for_discriminator=False)

        sum([g_loss, reg_loss]).mean().backward()
        self.loss_dict['loss_reg'] = reg_loss.item()
        self.loss_dict['loss_g'] = g_loss.item()

    def backward_D(self):
        d_fake = self.criterionGAN(self.d_fake_score, False)
        d_real = self.criterionGAN(self.d_real_score, True)

        sum([d_fake, d_real]).mean().backward()
        self.loss_dict['loss_d_fake'] = d_fake.item()
        self.loss_dict['loss_d_real'] = d_real.item()

    def optimize_parameters(self):
        for net in self.train_nets:
            net.train()

        self.forward()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    ##################################################################
    # Helper functions
    ##################################################################
    def update_learning_rate(self, epoch):
        if epoch > self.opt['niter']:
            lrd = self.opt['lr'] / self.opt['niter_decay']
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            new_lr_G = new_lr / 2
            new_lr_D = new_lr * 2

            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr

    def dxdy_reg_loss(self, y_hat, y):
        '''
        y_hat, y: B 2 H W
        return: B H W
        '''
        # cosine term: penalize angular deviation between predicted and
        # target orientation vectors (abs makes it sign-invariant)
        y_norm = y_hat / (torch.norm(y_hat, dim=1, keepdim=True) + 1e-7)
        cos = torch.abs(torch.sum(y_norm * y, dim=1, keepdim=False))
        # norm term: encourage the predicted vector to have unit length
        norm = torch.abs(
            torch.norm(y_hat, dim=1, keepdim=False) - torch.ones_like(cos))
        return 1 - cos + norm

    def discriminate(self, fake_image, real_image):
        # torch.cat over a single-element list is a no-op; the pattern is
        # kept from the usual case of concatenating a conditioning input
        fake_concat = torch.cat([fake_image], dim=1)
        real_concat = torch.cat([real_image], dim=1)

        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real)

        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        # the prediction contains the intermediate outputs of multiscale GAN,
        # so it's usually a list
        if type(pred) == list:
            fake = []
            real = []
            for p in pred:
                fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
                real.append([tensor[tensor.size(0) // 2:] for tensor in p])
        else:
            fake = pred[:pred.size(0) // 2]
            real = pred[pred.size(0) // 2:]

        return fake, real

    def inference(self, data):
        with torch.no_grad():
            image = data['image'].to(self.device)
            mask = data['mask'].to(self.device)
            mask_ = mask.unsqueeze(1).type(image.dtype)
            pred_dxdy = self.netG(torch.cat([image, mask_], dim=1))
            masked_pred_dxdy = torch.where(mask_ > 0., pred_dxdy,
                                           -torch.ones_like(pred_dxdy))
            return {'image': image, 'mask': mask, 'pred_dxdy': pred_dxdy,
                    'masked_pred_dxdy': masked_pred_dxdy}
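A hedged sketch of how DxdyModel might be driven from a training loop; the constructor call, options dict and dataloader are assumptions, not shown in this example:

    model = DxdyModel()
    model.initialize(opt)          # opt: dict-like options, as used above
    for epoch in range(opt['niter'] + opt['niter_decay']):
        for data in train_loader:  # dicts with 'image', 'mask', 'intensity', 'dxdy'
            model.set_input(data)
            model.optimize_parameters()
        model.update_visuals()
        model.update_learning_rate(epoch)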
Example #15
class Pix2PixTrainer():
    """
    Trainer creates the model and optimizers, and uses them to
    update the weights of the network while reporting losses
    and the latest visuals to visualize the progress in training.
    """
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)

        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model.cuda()
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None

        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

        # print(self.pix2pix_model_on_one_gpu.netG)
        # print(self.pix2pix_model_on_one_gpu.netD)

    def run_generator_one_step(self, data):
        self.optimizer_G.zero_grad()
        g_losses, generated, masked, semantics = self.pix2pix_model(
            data, mode='generator')
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.generated = generated
        self.masked = masked
        self.semantics = semantics

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        d_losses = self.pix2pix_model(data, mode='discriminator')
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        return self.generated

    def get_latest_real(self):
        return self.pix2pix_model_on_one_gpu.real_shape

    def get_semantics(self):
        return self.semantics

    def get_mask(self):
        if self.masked.shape[1] == 3:
            return self.masked
        else:
            return self.masked[:, :3]

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2

            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr
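For context, a typical SPADE-style driver for this trainer alternates the two steps each iteration; dataloader and the epoch range are assumptions taken from the usual training script:

    trainer = Pix2PixTrainer(opt)
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        for data_i in dataloader:
            trainer.run_generator_one_step(data_i)
            trainer.run_discriminator_one_step(data_i)
        trainer.update_learning_rate(epoch)
        trainer.save('latest')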
Example #16
class RotateTrainer(object):
    """
    Trainer creates the model and optimizers, and uses them to
    update the weights of the network while reporting losses
    and the latest visuals to visualize the progress in training.
    """
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = create_model(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model,
                device_ids=opt.gpu_ids,
                output_device=opt.gpu_ids[-1],
                chunk_size=opt.chunk_size)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model
        # self.Render = networks.Render(opt, render_size=opt.crop_size)
        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

    def use_gpu(self):
        return len(self.opt.gpu_ids) > 0

    def run_generator_one_step(self, data):
        self.optimizer_G.zero_grad()
        g_losses, generated = self.pix2pix_model.forward(data=data,
                                                         mode='generator')
        if not self.opt.train_rotate:
            with torch.no_grad():
                g_rotate_losses, generated_rotate = self.pix2pix_model.forward(
                    data=data, mode='generator_rotated')

        else:
            g_rotate_losses, generated_rotate = self.pix2pix_model.forward(
                data=data, mode='generator_rotated')
            g_losses['GAN_rotate'] = g_rotate_losses['GAN']
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        # g_rotate_loss = sum(g_rotate_losses.values()).mean()
        # g_rotate_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        # self.g_rotate_losses = g_rotate_losses
        self.generated = generated
        self.generated_rotate = generated_rotate

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        d_losses = self.pix2pix_model.forward(data=data, mode='discriminator')
        if self.opt.train_rotate:
            d_rotated_losses = self.pix2pix_model.forward(
                data=data, mode='discriminator_rotated')
            d_losses['D_rotate_Fake'] = d_rotated_losses['D_Fake']
            d_losses['D_rotate_real'] = d_rotated_losses['D_real']
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def get_latest_generated(self):
        return self.generated

    def get_latest_generated_rotate(self):
        return self.generated_rotate

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_current_visuals(self, data):
        return OrderedDict([('input_mesh', data['mesh']),
                            ('input_rotated_mesh', data['rotated_mesh']),
                            ('synthesized_image', self.get_latest_generated()),
                            ('synthesized_rotated_image',
                             self.get_latest_generated_rotate()),
                            ('real_image', data['image'])])

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2

            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr
Example #17
class Pix2PixTrainer():
    """
    Trainer creates the model and optimizers, and uses them to
    update the weights of the network while reporting losses
    and the latest visuals to visualize the progress in training.
    """
    def __init__(self, opt, resume_epoch=0):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 1:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model.to(opt.gpu_ids[0])
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        if opt.use_ema:
            self.netG_ema = EMA(opt.ema_beta)
            for name, param in self.pix2pix_model_on_one_gpu.net[
                    'netG'].named_parameters():
                if param.requires_grad:
                    self.netG_ema.register(name, param.data)
            self.netCorr_ema = EMA(opt.ema_beta)
            for name, param in self.pix2pix_model_on_one_gpu.net[
                    'netCorr'].named_parameters():
                if param.requires_grad:
                    self.netCorr_ema.register(name, param.data)

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
            if opt.continue_train and opt.which_epoch == 'latest':
                checkpoint = torch.load(
                    os.path.join(opt.checkpoints_dir, opt.name,
                                 'optimizer.pth'))
                self.optimizer_G.load_state_dict(checkpoint['G'])
                self.optimizer_D.load_state_dict(checkpoint['D'])
        self.last_data, self.last_netCorr, self.last_netG, self.last_optimizer_G = None, None, None, None

    def run_generator_one_step(self, data, alpha=1):
        self.optimizer_G.zero_grad()
        g_losses, out = self.pix2pix_model(data, mode='generator', alpha=alpha)
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.out = out
        if self.opt.use_ema:
            self.netG_ema(self.pix2pix_model_on_one_gpu.net['netG'])
            self.netCorr_ema(self.pix2pix_model_on_one_gpu.net['netCorr'])

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        GforD = {}
        GforD['fake_image'] = self.out['fake_image']
        GforD['adaptive_feature_seg'] = self.out['adaptive_feature_seg']
        GforD['adaptive_feature_img'] = self.out['adaptive_feature_img']
        d_losses = self.pix2pix_model(data, mode='discriminator', GforD=GforD)
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        return self.out['fake_image']

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)
        if self.opt.use_ema:
            self.netG_ema.assign(self.pix2pix_model_on_one_gpu.net['netG'])
            util.save_network(self.pix2pix_model_on_one_gpu.net['netG'],
                              'G_ema', epoch, self.opt)
            self.netG_ema.resume(self.pix2pix_model_on_one_gpu.net['netG'])

            self.netCorr_ema.assign(
                self.pix2pix_model_on_one_gpu.net['netCorr'])
            util.save_network(self.pix2pix_model_on_one_gpu.net['netCorr'],
                              'netCorr_ema', epoch, self.opt)
            self.netCorr_ema.resume(
                self.pix2pix_model_on_one_gpu.net['netCorr'])
        if epoch == 'latest':
            torch.save(
                {
                    'G': self.optimizer_G.state_dict(),
                    'D': self.optimizer_D.state_dict(),
                    'lr': self.old_lr,
                },
                os.path.join(self.opt.checkpoints_dir, self.opt.name,
                             'optimizer.pth'))

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2

            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr

    def update_fixed_params(self):
        for param in self.pix2pix_model_on_one_gpu.net['netCorr'].parameters():
            param.requires_grad = True
        G_params = [{
            'params': self.pix2pix_model_on_one_gpu.net['netG'].parameters(),
            'lr': self.opt.lr * 0.5
        }]
        G_params += [{
            'params': self.pix2pix_model_on_one_gpu.net['netCorr'].parameters(),
            'lr': self.opt.lr * 0.5
        }]
        if self.opt.no_TTUR:
            beta1, beta2 = self.opt.beta1, self.opt.beta2
            G_lr = self.opt.lr
        else:
            beta1, beta2 = 0, 0.9
            G_lr = self.opt.lr / 2

        self.optimizer_G = torch.optim.Adam(G_params,
                                            lr=G_lr,
                                            betas=(beta1, beta2),
                                            eps=1e-3)
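A note on update_fixed_params: flipping requires_grad on netCorr alone would have no effect on training, because an optimizer only updates the parameter groups it was constructed with; rebuilding optimizer_G is what actually enrolls the newly unfrozen netCorr parameters (here with both groups at half the base learning rate, matching the TTUR generator rate).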
Example #18
class Pix2PixTrainer():
    """
    Trainer creates the model and optimizers, and uses them to
    update the weights of the network while reporting losses
    and the latest visuals to visualize the progress in training.
    """
    def __init__(self, opt):
        self.opt = opt
        if self.opt.model == 'pix2pix':
            self.pix2pix_model = Pix2pixModel(opt)
        elif self.opt.model == 'smis':
            self.pix2pix_model = SmisModel(opt)
        print(self.pix2pix_model)
        with open(os.path.join(opt.checkpoints_dir, opt.name, 'model.txt'),
                  'w') as f:
            f.write(self.pix2pix_model.__str__())
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

    def run_generator_one_step(self, data):
        self.optimizer_G.zero_grad()
        g_losses, generated = self.pix2pix_model(data, mode='generator')
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.generated = generated

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        d_losses = self.pix2pix_model(data, mode='discriminator')
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def clean_grad(self):
        self.optimizer_D.zero_grad()
        self.optimizer_G.zero_grad()

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        return self.generated

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2

            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr
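As a worked example of this schedule: with lr=0.0002, niter=100 and niter_decay=100, the learning rate stays at 2e-4 for the first 100 epochs, then decreases by lrd = 0.0002 / 100 = 2e-6 per epoch, reaching zero at epoch 200; while TTUR is active (no_TTUR false), the generator trains at new_lr / 2 and the discriminator at new_lr * 2.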