Example 1
class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_audio_samples(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].cpu()
        out_dict['SR'] = self.fake_H.detach()[0].cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
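
The 'cb' branch above selects CharbonnierLoss, which is defined elsewhere in the codebase. As a hedged sketch only (the actual eps constant and reduction may differ), the Charbonnier criterion is a smooth L1 penalty of the form sqrt((x - y)^2 + eps^2):

import torch
import torch.nn as nn

class CharbonnierLossSketch(nn.Module):
    """Smooth L1 (Charbonnier) penalty: mean of sqrt(diff^2 + eps^2)."""
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        return torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))

# usage: behaves like nn.L1Loss but stays differentiable at zero error
loss = CharbonnierLossSketch()(torch.rand(2, 3, 8, 8), torch.rand(2, 3, 8, 8))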
Example 2
class RRSNetModel(BaseModel):
    def __init__(self, opt):
        super(RRSNetModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.l1_init = train_opt['l1_init'] if train_opt['l1_init'] else 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG,
                device_ids=[torch.cuda.current_device()],
                find_unused_parameters=True)
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netG.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Branch_init_iters
            self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt[
                'Branch_pretrain'] else 0
            self.Branch_init_iters = train_opt[
                'Branch_init_iters'] if train_opt['Branch_init_iters'] else 1

            # gradient_pixel_loss
            self.cri_pix_grad = nn.MSELoss().to(self.device)
            self.l_pix_grad_w = train_opt['gradient_pixel_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
            self.get_grad = Get_gradient()
            self.get_grad_nopadding = Get_gradient_nopadding()

        self.print_network()  # print network

    def feed_data(self, data, need_GT=True):
        self.var_LQ = data['LQ'].to(self.device)  # LQ
        self.var_LQ_UX4 = data['LQ_UX4'].to(self.device)
        self.var_Ref = data['Ref'].to(self.device)
        self.var_Ref_DUX4 = data['Ref_DUX4'].to(self.device)

        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            self.var_ref = data['GT'].clone().to(self.device)

    def optimize_parameters(self, step):
        # G

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref,
                                self.var_Ref_DUX4)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
        self.grad_LR = self.get_grad_nopadding(self.var_LQ)

        l_g_total = 0

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
        l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(
            self.fake_H_grad, self.var_H_grad)
        l_g_total = l_pix + l_g_pix_grad
        l_g_total.backward()
        self.optimizer_G.step()
        self.log_dict['l_g_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref,
                                    self.var_Ref_DUX4)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
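
RRSNetModel adds an MSE term between gradient maps produced by Get_gradient / Get_gradient_nopadding, which are defined elsewhere. Purely as an illustration of the idea (the real operators may use Sobel kernels, different padding, or per-channel weights), a minimal gradient-magnitude module could look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleGradient(nn.Module):
    """Per-channel gradient magnitude from fixed difference kernels (illustrative only)."""
    def __init__(self):
        super().__init__()
        kernel_v = torch.tensor([[0., -1., 0.], [0., 0., 0.], [0., 1., 0.]]).view(1, 1, 3, 3)
        kernel_h = torch.tensor([[0., 0., 0.], [-1., 0., 1.], [0., 0., 0.]]).view(1, 1, 3, 3)
        self.register_buffer('kernel_v', kernel_v)
        self.register_buffer('kernel_h', kernel_h)

    def forward(self, x):
        n, c, h, w = x.shape
        x = x.reshape(n * c, 1, h, w)                  # apply the same kernels to every channel
        g_v = F.conv2d(x, self.kernel_v, padding=1)
        g_h = F.conv2d(x, self.kernel_h, padding=1)
        grad = torch.sqrt(g_v ** 2 + g_h ** 2 + 1e-6)  # eps keeps the sqrt differentiable at 0
        return grad.reshape(n, c, h, w)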
Example 3
class VideoBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))

            self.l_pix_w = train_opt['pixel_weight']
            self.grad_w = train_opt['grad_weight']
            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0

    def optimize_parameters(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        pixel_loss = self.cri_pix(self.fake_H, self.real_H)
        g_loss = grad_loss(self.fake_H, self.real_H)
        l_pix = self.l_pix_w * pixel_loss + self.grad_w * g_loss

        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = pixel_loss.item()
        self.log_dict['l_grad'] = g_loss.item()

        self.log_dict['psnr'] = calPSNR(self.fake_H, self.real_H).item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
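
Besides the losses, this model logs PSNR through a calPSNR helper that is not shown here. Assuming inputs scaled to [0, 1], a minimal tensor PSNR can be computed as below; the actual helper may clamp values, crop borders, or average per frame:

import torch

def psnr_sketch(pred, target, max_val=1.0, eps=1e-12):
    """PSNR in dB for tensors assumed to lie in [0, max_val]."""
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / (mse + eps))

# identical tensors give a very large PSNR; independent uniform noise lands around 7-8 dB
print(psnr_sketch(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 32, 32)))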
Example 4
class VideoBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()
            self.loss_type = train_opt['pixel_criterion']

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif loss_type == 'cb+ssim':
                self.cri_pix = CharbonnierLossPlusSSIM(lambda_=train_opt['ssim_weight']).to(self.device)
            elif loss_type == 'cb+msssim':
                self.cri_pix = CharbonnierLossPlusMSSSIM(lambda_=train_opt['ssim_weight']).to(self.device)
            elif loss_type == 'msssim':
                self.cri_pix = MSSSIMLoss().to(self.device)
            elif loss_type == 'ssim':
                self.cri_pix = SSIMLoss().to(self.device)
            elif loss_type == 'cb+ssim+vmaf':
                self.cri_pix = CharbonnierLossPlusSSIMPlusVMAF().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))

            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:  # optimizers[0] =adam
                    self.schedulers.append(  # schedulers[0] = ReduceLROnPlateau
                        torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer, 'min', factor=train_opt['factor'],
                            patience=train_opt['patience'], verbose=True))
                print('Use ReduceLROnPlateau')
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0


    def optimize_parameters(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L) # 1 x 5 x 3 x 64 x 64

        loss, loss_tmp = self.cri_pix(self.fake_H, self.real_H)

        l_pix = self.l_pix_w * loss

        # if l_pix.item() > 1e-1:
        #     print('stop!')

        l_pix.backward()
        self.optimizer_G.step()

        if self.loss_type == 'cb+ssim':
            self.log_dict['total_loss'] = l_pix.item()
            self.log_dict['l_pix'] = loss_tmp[0].item()
            self.log_dict['ssim_loss'] = loss_tmp[1].item()
        else:
            self.log_dict['l_pix'] = l_pix.item()


    def optimize_parameters_without_schedule(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L) # 1 x 5 x 3 x 64 x 64

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)

        if l_pix.item() > 1e-1:
            print('stop!')

        l_pix.backward()

        self.optimizer_G.step()

        # for scheduler in self.schedulers:
        #     scheduler.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_stitch(self):
        """
                To hande the 4k output, we have no much GPU memory
                :return:
                """
        self.netG.eval()

        with torch.no_grad():
            imgs_in = self.var_L  # 1 NC HW

            # crop
            gtWidth = 3840
            gtHeight = 2160
            intWidth_ori = 960  # 960
            intHeight_ori = 540  # 540
            split_lengthY = 180
            split_lengthX = 320
            scale = 4
            PAD = 32

            intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * split_lengthX - intWidth_ori
            intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori

            intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_
            intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_

            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            # print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            intPaddingRight = PAD  # 32# 64# 128# 256
            intPaddingLeft = PAD  # 32#64 #128# 256
            intPaddingTop = PAD  # 32#64 #128#256
            intPaddingBottom = PAD  # 32#64 # 128# 256

            pader = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])

            imgs_in = torch.squeeze(imgs_in, 0)  # N C H W

            imgs_in = pader0(imgs_in)  # N C 540 960

            imgs_in = pader(imgs_in)  # N C 604 1024

            assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX))
            split_lengthY = int(split_lengthY)
            split_lengthX = int(split_lengthX)
            split_numY = int(float(intHeight_ori) / split_lengthY)
            split_numX = int(float(intWidth_ori) / split_lengthX)
            splitsY = range(0, split_numY)
            splitsX = range(0, split_numX)

            intWidth = split_lengthX
            intWidth_pad = intWidth + intPaddingLeft + intPaddingRight
            intHeight = split_lengthY
            intHeight_pad = intHeight + intPaddingTop + intPaddingBottom

            # print("split " + str(split_numY) + ' , ' + str(split_numX))
            # y_all = np.zeros((1, 3, gtHeight, gtWidth), dtype="float32")  # HWC
            y_all = torch.zeros((1, 3, gtHeight, gtWidth)).to(self.device)
            for split_j, split_i in itertools.product(splitsY, splitsX):
                # print(str(split_j) + ", \t " + str(split_i))
                X0 = imgs_in[:, :,
                     split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop,
                     split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft]

                # y_ = torch.FloatTensor()

                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                output = self.netG(X0) # 1 N C H W ->  1 C H W

                # if flip_test:
                #     output = util.flipx4_forward(model, X0)
                # else:
                #     output = util.single_forward(model, X0)

                output_depadded = output[:, :, intPaddingTop * scale:(intPaddingTop + intHeight) * scale,  # 1 C H W
                                  intPaddingLeft * scale: (intPaddingLeft + intWidth) * scale]

                # output_depadded = output_depadded.squeeze(0)  # C H W

                # output = util.tensor2img(output_depadded)  # C H W -> HWC

                # y_all[split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale,
                # split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale, :] = \
                #     np.round(output_depadded).astype(np.uint8)

                y_all[:, :, split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale,
                split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale] = output_depadded

            self.fake_H = y_all  # 1 N x c x 2160 x 3840

            self.netG.train()


    def get_current_log(self):
        return self.log_dict

    def get_loss(self):
        if self.opt['train']['pixel_criterion'] in ('cb+ssim', 'cb', 'ssim', 'msssim', 'cb+msssim'):
            loss_temp, _ = self.cri_pix(self.fake_H, self.real_H)
            l_pix = self.l_pix_w * loss_temp
        else:
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        return l_pix

    # def get_loss_ssim(self):
    #     l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
    #     # todo
    #     return l_pix



    def get_current_visuals(self, need_GT=True, save=False, name=None, save_path=None):

        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        if save:
            import os.path as osp
            import cv2
            img = out_dict['rlt']
            img = util.tensor2img(img)
            cv2.imwrite(osp.join(save_path, '{}.png'.format(name)), img)

        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
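
This variant is the only one that can register a ReduceLROnPlateau scheduler, which must be stepped with the monitored metric rather than unconditionally. A self-contained illustration of that difference (the optimizer, metric, and loop here are stand-ins, not the project's training script):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2)

for step in range(10):
    optimizer.step()          # stand-in for a training iteration
    val_loss = 1.0            # stand-in metric that never improves
    scheduler.step(val_loss)  # unlike the MultiStepLR-style schedulers, this call needs the metric
print(optimizer.param_groups[0]['lr'])  # lr has been halved after repeated non-improving epochs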
Example 5
class MWGANModel(BaseModel):
    def __init__(self, opt):
        super(MWGANModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.train_opt = opt['train']

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        # pretrained_dict = torch.load(opt['path']['pretrain_model_others'])
        # netG_dict = self.netG.state_dict()
        # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict}
        # netG_dict.update(pretrained_dict)
        # self.netG.load_state_dict(netG_dict)

        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            if not self.train_opt['only_G']:
                self.netD = networks.define_D(opt).to(self.device)
                # init_weights(self.netD)
                if opt['dist']:
                    self.netD = DistributedDataParallel(
                        self.netD, device_ids=[torch.cuda.current_device()])
                else:
                    self.netD = DataParallel(self.netD)

                self.netG.train()
                self.netD.train()
            else:
                self.netG.train()
        else:
            self.netG.train()

        # define losses, optimizer and scheduler
        if self.is_train:

            # G pixel loss
            if self.train_opt['pixel_weight'] > 0:
                l_pix_type = self.train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = self.train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            if self.train_opt['lpips_weight'] > 0:
                l_lpips_type = self.train_opt['lpips_criterion']
                if l_lpips_type == 'lpips':
                    self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device)
                    if opt['dist']:
                        self.cri_lpips = DistributedDataParallel(
                            self.cri_lpips,
                            device_ids=[torch.cuda.current_device()])
                    else:
                        self.cri_lpips = DataParallel(self.cri_lpips)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(
                            l_lpips_type))
                self.l_lpips_w = self.train_opt['lpips_weight']
            else:
                logger.info('Remove lpips loss.')
                self.cri_lpips = None

            # G feature loss
            if self.train_opt['feature_weight'] > 0:
                self.fea_trans = GramMatrix().to(self.device)
                l_fea_type = self.train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = self.train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = self.train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = self.train_opt[
                'D_update_ratio'] if self.train_opt['D_update_ratio'] else 1
            self.D_init_iters = self.train_opt[
                'D_init_iters'] if self.train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = self.train_opt['weight_decay_G'] if self.train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=self.train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            if not self.train_opt['only_G']:
                # D
                wd_D = self.train_opt['weight_decay_D'] if self.train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=self.train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(self.train_opt['beta1_D'],
                           self.train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if self.train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            self.train_opt['lr_steps'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights'],
                            gamma=self.train_opt['lr_gamma'],
                            clear_state=self.train_opt['clear_state']))
            elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            self.train_opt['T_period'],
                            eta_min=self.train_opt['eta_min'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        if self.is_train:
            if not self.train_opt['only_G']:
                self.print_network()  # print network
        else:
            self.print_network()  # print network

        try:
            self.load()  # load G and D if needed
            print('Pretrained model loaded')
        except Exception as e:
            print('No pretrained model found')

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            # print(self.var_H.size())
            self.var_H = self.var_H.squeeze(1)
            # self.var_H = self.DWT(self.var_H)

            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)
            # print(self.var_ref.size())
            self.var_ref = self.var_ref.squeeze(1)
            # print(s)
            # self.var_ref = self.DWT(self.var_ref)

    def process_list(self, input1, input2):
        result = []
        for index in range(len(input1)):
            result.append(input1[index] - torch.mean(input2[index]))
        return result

    def optimize_parameters(self, step):
        # G
        if not self.train_opt['only_G']:
            for p in self.netD.parameters():
                p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.fake_H = self.netG(self.var_L)

        # self.var_H = self.var_H.squeeze(1)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix

            if self.cri_lpips:  # LPIPS perceptual loss
                l_g_lpips = torch.mean(
                    self.l_lpips_w *
                    self.cri_lpips.forward(self.fake_H, self.var_H))
                l_g_total += l_g_lpips

            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                real_fea_trans = self.fea_trans(real_fea)
                fake_fea_trans = self.fea_trans(fake_fea)
                l_g_fea_trans = self.l_fea_w * self.cri_fea(
                    fake_fea_trans, real_fea_trans) * 10
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
                l_g_total += l_g_fea_trans

            if not self.train_opt['only_G']:
                pred_g_fake = self.netD(self.fake_H)

                if self.opt['train']['gan_type'] == 'gan':
                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                elif self.opt['train']['gan_type'] == 'ragan':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    pred_d_real = self.netD(self.var_ref)
                    pred_d_real = [ele.detach() for ele in pred_d_real]
                    l_g_gan = self.l_gan_w * (self.cri_gan(
                        self.process_list(pred_d_real, pred_g_fake), False
                    ) + self.cri_gan(
                        self.process_list(pred_g_fake, pred_d_real), True)) / 2
                elif self.opt['train']['gan_type'] == 'lsgan_ra':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    pred_d_real = self.netD(self.var_ref)
                    pred_d_real = [ele.detach() for ele in pred_d_real]
                    # l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                    l_g_gan = self.l_gan_w * (self.cri_gan(
                        self.process_list(pred_d_real, pred_g_fake), False
                    ) + self.cri_gan(
                        self.process_list(pred_g_fake, pred_d_real), True)) / 2
                elif self.opt['train']['gan_type'] == 'lsgan':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                l_g_total += l_g_gan

            l_g_total.backward()
            self.optimizer_G.step()
        else:
            pass  # generator is not updated this step (gated by D_update_ratio / D_init_iters)

        if not self.train_opt['only_G']:
            # D
            for p in self.netD.parameters():
                p.requires_grad = True

            self.optimizer_D.zero_grad()
            l_d_total = 0
            pred_d_real = self.netD(self.var_ref)
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G

            if self.opt['train']['gan_type'] == 'gan':
                l_d_real = self.cri_gan(pred_d_real, True)
                l_d_fake = self.cri_gan(pred_d_fake, False)
                l_d_total += l_d_real + l_d_fake
            elif self.opt['train']['gan_type'] == 'ragan':
                l_d_real = self.cri_gan(
                    self.process_list(pred_d_real, pred_d_fake), True)
                l_d_fake = self.cri_gan(
                    self.process_list(pred_d_fake, pred_d_real), False)
                l_d_total += (l_d_real + l_d_fake) / 2
            elif self.opt['train']['gan_type'] == 'lsgan':
                l_d_real = self.cri_gan(pred_d_real, True)
                l_d_fake = self.cri_gan(pred_d_fake, False)
                l_d_total += (l_d_real + l_d_fake) / 2

            l_d_total.backward()
            self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item() / self.l_pix_w
            if self.cri_lpips:
                self.log_dict['l_g_lpips'] = l_g_lpips.item() / self.l_lpips_w
            if not self.train_opt['only_G']:
                self.log_dict['l_g_gan'] = l_g_gan.item() / self.l_gan_w
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item() / self.l_fea_w
                self.log_dict['l_g_fea_trans'] = l_g_fea_trans.item() / self.l_fea_w / 10

        if not self.train_opt['only_G']:
            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            self.log_dict['D_real'] = torch.mean(pred_d_real[0].detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake[0].detach())

    def test(self, load_path=None, input_u=None, input_v=None):

        if load_path is not None:
            self.load_network(load_path, self.netG,
                              self.opt['path']['strict_load'])
            print(
                '***************************************************************'
            )
            print('Load model successfully')
            print(
                '***************************************************************'
            )

        self.netG.eval()
        # self.var_H = self.var_H.squeeze(1)
        # img_to_write = self.var_L.detach()[0].float().cpu()
        # print(img_to_write.size())
        # cv2.imwrite('./test.png',img_to_write.numpy().transpose(1,2,0)*255)
        with torch.no_grad():
            if self.var_L.size()[-1] > 1280:
                width = self.var_L.size()[-1]
                height = self.var_L.size()[-2]
                fake_list = []
                for height_start in [0, int(height / 2)]:
                    for width_start in [0, int(width / 2)]:
                        self.fake_slice = self.netG(
                            self.var_L[:, :, :, height_start:(height_start +
                                                              int(height / 2)),
                                       width_start:(width_start +
                                                    int(width / 2))])
                        fake_list.append(self.fake_slice)
                enhanced_frame_h1 = torch.cat([fake_list[0], fake_list[2]], 2)
                enhanced_frame_h2 = torch.cat([fake_list[1], fake_list[3]], 2)
                self.fake_H = torch.cat([enhanced_frame_h1, enhanced_frame_h2],
                                        3)
            else:
                self.fake_H = self.netG(self.var_L)
            if input_u is not None and input_v is not None:
                self.var_L_u = input_u.to(self.device)
                self.var_L_v = input_v.to(self.device)
                self.fake_H_u_s = self.netG(self.var_L_u.float())
                self.fake_H_v_s = self.netG(self.var_L_v.float())
                # self.fake_H_u = torch.cat((self.fake_H_u_s[0], self.fake_H_u_s[1]), 1)
                # self.fake_H_v = torch.cat((self.fake_H_v_s[0], self.fake_H_v_s[1]), 1)
                self.fake_H_u = self.fake_H_u_s
                self.fake_H_v = self.fake_H_v_s
                # self.fake_H_u = self.IWT(self.fake_H_u)
                # self.fake_H_v = self.IWT(self.fake_H_v)
            else:
                self.fake_H_u = None
                self.fake_H_v = None
            self.fake_H_all = self.fake_H
            if self.opt['network_G']['out_nc'] == 4:
                self.fake_H_all = self.IWT(self.fake_H_all)
                if input_u is not None and input_v is not None:
                    self.fake_H_u = self.IWT(self.fake_H_u)
                    self.fake_H_v = self.IWT(self.fake_H_v)
        # self.fake_H = self.var_L[:,2,:,:,:]
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0][2].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if self.fake_H_u is not None:
            out_dict['SR_U'] = self.fake_H_u.detach()[0].float().cpu()
            out_dict['SR_V'] = self.fake_H_v.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
            print('G loaded')
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
            print('D loaded')

    def save(self, iter_step):
        if not self.train_opt['only_G']:
            self.save_network(self.netG, 'G', iter_step)
            self.save_network(self.netD, 'D', iter_step)
        else:
            self.save_network(self.netG,
                              self.opt['network_G']['which_model_G'],
                              iter_step, self.opt['path']['pretrain_model_G'])
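
In the 'ragan' and 'lsgan_ra' branches above, process_list applies the relativistic-average trick to each element of the discriminator's (multi-scale) output list: every prediction is shifted by the mean of the opposing prediction before the GAN criterion is applied. For a single discriminator output and a BCE-with-logits criterion (an assumption about what GANLoss wraps), the generator term reduces to:

import torch
import torch.nn as nn

def ragan_generator_loss(pred_d_real, pred_g_fake, criterion):
    """Relativistic-average GAN generator loss, single-output case (illustrative)."""
    real_rel = pred_d_real.detach() - torch.mean(pred_g_fake)
    fake_rel = pred_g_fake - torch.mean(pred_d_real.detach())
    # the generator wants relativistic real scores to look fake, and relativistic fake scores to look real
    return (criterion(real_rel, torch.zeros_like(real_rel)) +
            criterion(fake_rel, torch.ones_like(fake_rel))) / 2

criterion = nn.BCEWithLogitsLoss()
loss = ragan_generator_loss(torch.randn(4, 1), torch.randn(4, 1), criterion)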
Example 6
class Ranker_Model(BaseModel):
    def name(self):
        return 'Ranker_Model'

    def __init__(self, opt):
        super(Ranker_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netR = networks.define_R(opt).to(self.device)
        if opt['dist']:
            self.netR = DistributedDataParallel(self.netR, device_ids=[torch.cuda.current_device()])
        else:
            self.netR = DataParallel(self.netR)
        self.load()

        if self.is_train:
            self.netR.train()

            # loss
            self.RankLoss = nn.MarginRankingLoss(margin=0.5)
            self.RankLoss.to(self.device)
            # note: named L2Loss in the original code, but the criterion actually used is L1
            self.L2Loss = nn.L1Loss()
            self.L2Loss.to(self.device)
            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netR.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_R = torch.optim.Adam(optim_params, lr=train_opt['lr_R'], weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_R)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, need_img2=True):
        # input img1
        self.input_img1 = data['img1'].to(self.device)

        # label score1
        self.label_score1 = data['score1'].to(self.device)

        if need_img2:
            # input img2
            self.input_img2 = data['img2'].to(self.device)

            # label score2
            self.label_score2 = data['score2'].to(self.device)

            # rank label
            self.label = self.label_score1 >= self.label_score2  # get a ByteTensor
            # convert to FloatTensor
            self.label = self.label.float()
            # map the label values to -1 or 1
            self.label = (self.label - 0.5) * 2


    def optimize_parameters(self, step):
        self.optimizer_R.zero_grad()
        # use the ranker network to predict a score for each image
        self.predict_score1 = self.netR(self.input_img1)
        self.predict_score2 = self.netR(self.input_img2)

        # clamp the predicted scores to a bounded range
        self.predict_score1 = torch.clamp(self.predict_score1, min=-5, max=5)
        self.predict_score2 = torch.clamp(self.predict_score2, min=-5, max=5)

        # compute the margin ranking loss and minimize l_rank
        l_rank = self.RankLoss(self.predict_score1, self.predict_score2, self.label)

        l_rank.backward()
        self.optimizer_R.step()

        # set log
        self.log_dict['l_rank'] = l_rank.item()

    def test(self):
        self.netR.eval()
        with torch.no_grad():
            self.predict_score1 = self.netR(self.input_img1)
        self.netR.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['predict_score1'] = self.predict_score1.data[0].float().cpu()

        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netR)
        if isinstance(self.netR, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netR.__class__.__name__,
                                             self.netR.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netR.__class__.__name__)
        logger.info('Network R structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info('Loading pretrained model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR)

    def save(self, iter_step):
        self.save_network(self.netR, 'R', iter_step)
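
The rank label above is mapped from a boolean comparison to -1/+1 because nn.MarginRankingLoss expects a target of that form: y = 1 means the first input should rank higher than the second, y = -1 the opposite. A minimal, self-contained sketch (toy scores, not the Ranker network) of that convention:

import torch
import torch.nn as nn

score1 = torch.tensor([0.8, 0.1])
score2 = torch.tensor([0.3, 0.9])
label = (score1 >= score2).float()   # {0., 1.}
label = (label - 0.5) * 2            # {-1., +1.}, as in feed_data above

rank_loss = nn.MarginRankingLoss(margin=0.5)
print(rank_loss(score1, score2, label))  # tensor(0.) - both pairs satisfy the margin
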
Example no. 7
class VideoSRBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoSRBaseModel, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss(reduction="sum").to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss(reduction="sum").to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type))
            self.l_pix_w = train_opt["pixel_weight"]

            #### optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            if train_opt["ft_tsa_only"]:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if "tsa_fusion" in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))
                optim_params = [
                    {  # add normal params first
                        "params": normal_params,
                        "lr": train_opt["lr_G"],
                    },
                    {"params": tsa_fusion_params, "lr": train_opt["lr_G"]},
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))

            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data["LQs"].to(self.device)
        if need_GT:
            self.real_H = data["GT"].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]["lr"] = 0

    def optimize_parameters(self, step):
        if self.opt["train"][
                "ft_tsa_only"] and step < self.opt["train"]["ft_tsa_only"]:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict["l_pix"] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["restore"] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt["path"]["strict_load"])

    def save(self, iter_label):
        self.save_network(self.netG, "G", iter_label)
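
VideoSRBaseModel keeps the non-TSA parameters in param group 0 and the tsa_fusion parameters in group 1, so set_params_lr_zero can freeze the backbone by zeroing group 0's learning rate during the ft_tsa_only warm-up. A minimal sketch of that param-group trick, with two toy linear layers standing in for the backbone and the TSA fusion module:

import torch
import torch.nn as nn

backbone = nn.Linear(4, 4)      # stand-in for the "normal" parameters
tsa_fusion = nn.Linear(4, 4)    # stand-in for the tsa_fusion parameters
optimizer = torch.optim.Adam([
    {'params': backbone.parameters(), 'lr': 1e-4},    # group 0: frozen during warm-up
    {'params': tsa_fusion.parameters(), 'lr': 1e-4},  # group 1: always trained
])

def set_params_lr_zero(opt):
    opt.param_groups[0]['lr'] = 0   # fix the normal module

set_params_lr_zero(optimizer)
print([g['lr'] for g in optimizer.param_groups])  # [0, 0.0001]
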
Example no. 8
class FIRNModel(BaseModel):
    def __init__(self, opt):
        super(FIRNModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_back'])

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y, z):
        l_forw_fit = self.train_opt[
            'lambda_fit_forw'] * self.Reconstruction_forw(out, y)

        z = z.reshape([out.shape[0], -1])
        l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(
            z**2) / z.shape[0]

        return l_forw_fit, l_forw_ce

    def loss_backward(self, x, y):
        x_samples = self.netG(x=y, rev=True)
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt[
            'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)

        return l_back_rec

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()

        # forward downscaling
        self.input = self.real_H
        self.output = self.netG(x=self.input)

        zshape = self.output[:, 3:, :, :].shape
        LR_ref = self.ref_L.detach()

        l_forw_fit, l_forw_ce = self.loss_forward(self.output[:, :3, :, :],
                                                  LR_ref,
                                                  self.output[:, 3:, :, :])

        # backward upscaling
        LR = self.Quantization(self.output[:, :3, :, :])
        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        l_back_rec = self.loss_backward(self.real_H, y_)

        # total loss
        loss = l_forw_fit + l_back_rec + l_forw_ce
        loss.backward()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

        # set log
        self.log_dict['l_forw_fit'] = l_forw_fit.item()
        self.log_dict['l_forw_ce'] = l_forw_ce.item()
        self.log_dict['l_back_rec'] = l_back_rec.item()

    def test(self):
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H

        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]

        self.netG.train()

    def downscale(self, HR_img):
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            LR_img = self.Quantization(LR_img)  # quantize the LR image just produced
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
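
In FIRNModel the forward pass of the invertible network outputs 3 LR image channels plus latent channels, and the reverse pass replaces the latent part with Gaussian noise of matching shape (gaussian_batch). A minimal sketch with toy sizes showing where the zshape used in test() and upscale() comes from:

import torch

scale = 4
B, C, H, W = 2, 3, 64, 64

# an x4 invertible downscaler maps (C, H, W) to (C * scale**2, H/scale, W/scale);
# the first C channels are the LR image, the remaining ones are the latent z
lr = torch.rand(B, C, H // scale, W // scale)
zshape = [B, C * scale ** 2 - C, H // scale, W // scale]

gaussian_scale = 1
y_ = torch.cat((lr, gaussian_scale * torch.randn(zshape)), dim=1)
print(y_.shape)  # torch.Size([2, 48, 16, 16]) -> input of the reverse pass
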
Example no. 9
class LRimgestimator_Model(BaseModel):
    def name(self):
        return 'Estimator_Model'

    def __init__(self, opt):
        super(LRimgestimator_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.kernel_size = opt['datasets']['train']['kernel_size']
        self.patch_size = opt['datasets']['train']['patch_size']
        self.batch_size = opt['datasets']['train']['batch_size']

        # define networks and load pretrained models
        self.scale = opt['scale']
        self.model_name = opt['network_E']['which_model_E']
        self.mode = opt['network_E']['mode']

        self.netE = networks.define_E(opt).to(self.device)
        if opt['dist']:
            self.netE = DistributedDataParallel(
                self.netE, device_ids=[torch.cuda.current_device()])
        else:
            self.netE = DataParallel(self.netE)
        self.load()

        # loss
        if train_opt['loss_ftn'] == 'l1':
            self.MyLoss = nn.L1Loss(reduction='mean').to(self.device)
        elif train_opt['loss_ftn'] == 'l2':
            self.MyLoss = nn.MSELoss(reduction='mean').to(self.device)
        else:
            self.MyLoss = None

        if self.is_train:
            self.netE.train()

            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt[
                'weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netE.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_E = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_C'],
                                                weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_E)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        #print('---------- Model initialized ------------------')
        #self.print_network()
        #print('-----------------------------------------------')

    def feed_data(self, data):
        self.real_H = data['LQs'].to(self.device)
        self.real_L = None if 'SuperLQs' not in data.keys(
        ) else data['SuperLQs'].to(self.device)
        B, T, C, H, W = self.real_H.shape
        if self.mode == 'image':
            self.var_H = self.real_H.reshape(B * T, C, H, W)
        else:
            self.var_H = self.real_H.transpose(1, 2)  # B C T H W

    def optimize_parameters(self, step=None):
        self.optimizer_E.zero_grad()
        fake_L = self.netE(self.var_H)
        if self.mode == 'image':
            H, W = fake_L.shape[-2:]
            B, T, C = self.real_H.shape[:3]
            self.fake_L = fake_L.reshape(B, T, C, H, W)
        else:
            self.fake_L = fake_L.transpose(1, 2)
        LR_loss = self.MyLoss(self.fake_L, self.real_L)
        # set log
        self.log_dict['l_pix'] = LR_loss.item()
        # Show the std of real, fake kernel
        LR_loss.backward()
        self.optimizer_E.step()

    def forward_without_optim(self, step=None):
        fake_L = self.netE(self.var_H)
        if self.mode == 'image':
            H, W = fake_L.shape[-2:]
            B, T, C = self.real_H.shape[:3]
            self.fake_L = fake_L.reshape(B, T, C, H, W)
        else:
            self.fake_L = fake_L.transpose(1, 2)

    def test(self):
        self.netE.eval()
        with torch.no_grad():
            fake_L = self.netE(self.var_H)
            if self.mode == 'image':
                H, W = fake_L.shape[-2:]
                B, T, C = self.real_H.shape[:3]
                self.fake_L = fake_L.reshape(B, T, C, H, W)
            else:
                self.fake_L = fake_L.transpose(1, 2)
        self.netE.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        T = self.fake_L.size(1)
        out_dict['LQ'] = self.real_L.detach()[0, T // 2].float().cpu()
        out_dict['rlt'] = self.fake_L.detach()[0, T // 2].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0, T // 2].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netE)
        if isinstance(self.netE, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netE.__class__.__name__,
                self.netE.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netE.__class__.__name__)
        logger.info('Network R structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_E = self.opt['path']['pretrain_model_E']
        if load_path_E is not None:
            logger.info('Loading pretrained model for E [{:s}] ...'.format(
                load_path_E))
            self.load_network(load_path_E, self.netE)

    def save(self, iter_step):
        self.save_network(self.netE, 'E', iter_step)
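
LRimgestimator_Model feeds the same (B, T, C, H, W) clip to the estimator in two layouts: as a flat batch of frames when mode == 'image', or time-as-depth as (B, C, T, H, W) otherwise. A minimal sketch with toy tensors:

import torch

B, T, C, H, W = 2, 5, 3, 32, 32
clip = torch.rand(B, T, C, H, W)

as_images = clip.reshape(B * T, C, H, W)  # mode == 'image': per-frame 2D estimator
as_video = clip.transpose(1, 2)           # otherwise: (B, C, T, H, W) for a 3D estimator
print(as_images.shape, as_video.shape)
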
Example no. 10
class RRDBM(BaseModel):
    def __init__(self, opt):
        super(RRDBM, self).__init__(opt)

        # define networks and load pretrained models
        train_opt = opt['train']

        self.netG_R = define_SR(opt).to(self.device)

        if opt['dist']:
            self.netG_R = DistributedDataParallel(
                self.netG_R, device_ids=[torch.cuda.current_device()])

        else:
            self.netG_R = DataParallel(self.netG_R)
        # define losses, optimizer and scheduler
        if self.is_train:
            # losses
            # if train_opt['l_pixel_type']=="L1":
            #     self.criterionPixel= torch.nn.L1Loss().to(self.device)
            # elif train_opt['l_pixel_type']=="CR":
            #     self.criterionPixel=CharbonnierLoss().to(self.device)
            #
            # else:
            #     raise NotImplementedError("pixel_type does not implement still")
            self.criterionPixel = SRLoss(
                loss_type=train_opt['l_pixel_type']).to(self.device)
            # optimizers
            self.optimizer_G = torch.optim.Adam(self.netG_R.parameters(),
                                                lr=train_opt['lr'],
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #scheduler
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError("lr_scheme does not implement still")

            self.log_dict = OrderedDict()
            self.train_state()

        self.load()  # load R

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def feed_data(self, data):
        self.LQ = data['LQ'].to(self.device)
        self.HQ = data['HQ'].to(self.device)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""

        self.fake_HQ = self.netG_R(self.LQ)

    def backward_G(self, step):
        """Calculate the loss for generators G_A and G_B"""

        self.loss_G_pixel = self.criterionPixel(self.fake_HQ, self.HQ)
        if len(self.loss_G_pixel) == 2:
            if self.opt['train']['other_step'] < step:
                self.loss_G_total = self.loss_G_pixel[0] * self.opt['train']['l_l1_weight']+ \
                                    self.loss_G_pixel[1] * self.opt['train']['l_ssim_weight']
            else:
                self.loss_G_total = self.loss_G_pixel[0] * self.opt['train'][
                    'l_l1_weight']
        else:

            self.loss_G_total = self.loss_G_pixel[0] * self.opt['train'][
                'l_l1_weight']

        self.loss_G_total.backward()

    def optimize_parameters(self, step):
        # G
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute fake images and reconstruction images.
        # G
        self.optimizer_G.zero_grad()  # set G gradients to zero
        self.backward_G(step)  # calculate gradients for G
        self.optimizer_G.step()  # update G's weights

        # set log
        for i in range(len(self.loss_G_pixel)):
            self.log_dict[str(i)] = self.loss_G_pixel[i].item()
        # self.log_dict['loss_l1'] = self.loss_G_pixel.item() if self.opt['train']['l_l1_weight']!=0 else 0

    def train_state(self):
        self.netG_R.train()

    def test_state(self):
        self.netG_R.eval()

    def val(self):
        self.test_state()
        with torch.no_grad():
            self.forward()
        self.train_state()

    def test(self, img):

        self.netG_R.eval()
        with torch.no_grad():

            SR = self.netG_R(img)
        return SR

    def get_network(self):
        return self.netG_R

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals_and_cal_metric(self, opt, current_step):

        visuals = [
            F.interpolate(self.LQ,
                          scale_factor=self.opt['datasets']['train']['scale'],
                          mode='bilinear',
                          align_corners=True), self.fake_HQ, self.HQ
        ]

        util.write_2images(visuals, opt['datasets']['val']['batch_size'],
                           opt['path']['val_images'],
                           'test_%08d' % (current_step))

        # HTML
        util.write_html(opt['path']['experiments_root'] + "/index.html",
                        (current_step), opt['train']['val_freq'],
                        opt['path']['val_images'])

        # src BGR, range [0-255], HWC
        srimg = util.tensor2img(self.fake_HQ)
        hrimg = util.tensor2img(self.HQ)

        psnr = calculate_psnr(srimg, hrimg)
        ssim = calculate_ssim(srimg, hrimg)
        return {"psnr": psnr, "ssim": ssim}

    def print_network(self):

        if self.is_train:
            # Generator
            s, n = self.get_network_description(self.netG_R)
            net_struc_str = '{} - {}'.format(
                self.netG_R.__class__.__name__,
                self.netG_R.module.__class__.__name__)
            logger.info(
                'Network G_R structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G_R = self.opt['path']['pretrain_model_G_R']

        if load_path_G_R is not None:
            logger.info(
                'Loading models for G [{:s}] ...'.format(load_path_G_R))
            self.load_network(load_path_G_R, self.netG_R,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG_R, 'G_R', iter_step)
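
set_requires_grad in RRDBM is the usual helper for freezing whole networks (for instance a discriminator while the generator is being updated). A standalone sketch of the same helper applied to a toy module:

import torch.nn as nn

def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

net_d = nn.Linear(8, 1)
set_requires_grad(net_d, False)  # freeze every parameter of net_d
print(all(not p.requires_grad for p in net_d.parameters()))  # True
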
Example no. 11
class B_Model(BaseModel):
    def __init__(self, opt):
        super(B_Model, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()]
            )
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            train_opt = opt["train"]
            # self.init_model()  # skipping explicit init is fine; PyTorch applies its own default init
            self.netG.train()

            # loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type)
                )
            self.l_pix_w = train_opt["pixel_weight"]

            # optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0
            optim_params = []
            for (
                k,
                v,
            ) in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning("Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            # self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        )
                    )
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        )
                    )
            else:
                print("MultiStepLR learning rate scheme is enough.")

            self.log_dict = OrderedDict()

    def init_model(self, scale=0.1):
        # Common practice for initialization.
        for layer in self.netG.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, a=0, mode="fan_in")
                layer.weight.data *= scale  # for residual block
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, a=0, mode="fan_in")
                layer.weight.data *= scale
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias.data, 0.0)

    def feed_data(self, LR_img, GT_img=None, ker_map=None):
        self.var_L = LR_img.to(self.device)
        if not (GT_img is None):
            self.real_H = GT_img.to(self.device)
        if not (ker_map is None):
            self.real_ker = ker_map.to(self.device)

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        srs, ker_maps = self.netG(self.var_L)

        self.fake_SR = srs[-1]
        self.fake_ker = ker_maps[-1]

        total_loss = 0
        for ind in range(len(ker_maps)):
            d_kr = self.cri_pix(ker_maps[ind], self.real_ker)

            d_sr = self.cri_pix(srs[ind], self.real_H)

            self.log_dict["l_pix%d" % ind] = d_sr.item()
            self.log_dict["l_ker%d" % ind] = d_kr.item()

        # note: these additions sit outside the loop above, so only the final
        # stage's d_sr and d_kr contribute to the total loss as written
        total_loss += d_sr
        total_loss += d_kr

        total_loss.backward()
        self.optimizer_G.step()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            srs, kermaps = self.netG(self.var_L)
            self.fake_SR = srs[-1]
            self.fake_ker = kermaps[-1]
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == "v":
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == "h":
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == "t":
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in "v", "h", "t":
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug)[0] for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], "t")
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], "h")
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], "v")

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["SR"] = self.fake_SR.detach()[0].float().cpu()
        out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        out_dict["ker"] = self.fake_ker.detach()[0].float().cpu()
        out_dict["Batch_SR"] = (
            self.fake_SR.detach().float().cpu()
        )  # Batch SR, for train
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
            self.netG, DistributedDataParallel
        ):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__, self.netG.module.__class__.__name__
            )
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n
                )
            )
            logger.info(s)

    def load(self):
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"])

    def save(self, iter_label):
        self.save_network(self.netG, "G", iter_label)
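
test_x8 above is the usual geometric self-ensemble: build eight flipped/transposed copies of the input, run the network on each, undo the transform on each output, and average. A minimal sketch with an identity model standing in for self.netG (so the averaged result must equal the input):

import torch

def transform(v, op):
    if op == 'v':
        return v.flip(-1)          # flip the last (width) axis
    if op == 'h':
        return v.flip(-2)          # flip the height axis
    return v.transpose(-1, -2)     # op == 't': transpose H and W

def model(x):
    return x                       # stand-in for self.netG

lr = torch.rand(1, 3, 8, 8)
lr_list = [lr]
for op in ('v', 'h', 't'):
    lr_list.extend([transform(t, op) for t in lr_list])

sr_list = [model(aug) for aug in lr_list]
for i in range(len(sr_list)):
    if i > 3:
        sr_list[i] = transform(sr_list[i], 't')
    if i % 4 > 1:
        sr_list[i] = transform(sr_list[i], 'h')
    if (i % 4) % 2 == 1:
        sr_list[i] = transform(sr_list[i], 'v')

sr = torch.cat(sr_list, dim=0).mean(dim=0, keepdim=True)
print(torch.allclose(sr, lr))  # True for the identity model
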
Example no. 12
class Model:
    """
    This class handles basic methods for handling the model:
    1. Fit the model
    2. Make predictions
    3. Save
    4. Load
    """
    def __init__(self, input_size, n_channels, hparams):

        self.hparams = hparams

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # define the models
        self.model = WaveNet(n_channels=n_channels).to(self.device)
        summary(self.model, (input_size, n_channels))
        # self.model.half()

        if torch.cuda.device_count() > 1:
            print("Number of GPUs will be used: ",
                  torch.cuda.device_count() - 3)
            self.model = DP(self.model,
                            device_ids=list(
                                range(torch.cuda.device_count() - 3)))
        else:
            print('Only one GPU is available')

        self.metric = Metric()
        self.num_workers = 1
        ########################## compile the model ###############################

        # define optimizer
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.hparams['lr'],
                                          weight_decay=1e-5)

        # weights = torch.Tensor([0.025,0.033,0.039,0.046,0.069,0.107,0.189,0.134,0.145,0.262,1]).cuda()
        self.loss = nn.BCELoss()  # CompLoss(self.device)

        # define early stopping
        self.early_stopping = EarlyStopping(
            checkpoint_path=self.hparams['checkpoint_path'] + '/checkpoint.pt',
            patience=self.hparams['patience'],
            delta=self.hparams['min_delta'],
        )
        # lr cheduler
        self.scheduler = ReduceLROnPlateau(
            optimizer=self.optimizer,
            mode='max',
            factor=0.2,
            patience=3,
            verbose=True,
            threshold=self.hparams['min_delta'],
            threshold_mode='abs',
            cooldown=0,
            eps=0,
        )

        self.seed_everything(42)
        self.threshold = 0.75
        self.scaler = torch.cuda.amp.GradScaler()

    def seed_everything(self, seed):
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.manual_seed(seed)

    def fit(self, train, valid):

        train_loader = DataLoader(
            train,
            batch_size=self.hparams['batch_size'],
            shuffle=True,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate
        valid_loader = DataLoader(
            valid,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate

        # tensorboard object
        writer = SummaryWriter()

        for epoch in range(self.hparams['n_epochs']):

            # train the model
            self.model.train()
            avg_loss = 0.0

            train_preds, train_true = torch.Tensor([]), torch.Tensor([])

            for (X_batch, y_batch) in tqdm(train_loader):
                y_batch = y_batch.float().to(self.device)
                X_batch = X_batch.float().to(self.device)

                self.optimizer.zero_grad()
                # get model predictions
                pred = self.model(X_batch)
                X_batch = X_batch.cpu().detach()

                # process loss_1
                pred = pred.view(-1, pred.shape[-1])
                y_batch = y_batch.view(-1, y_batch.shape[-1])
                train_loss = self.loss(pred, y_batch)
                y_batch = y_batch.float().cpu().detach()
                pred = pred.float().cpu().detach()

                train_loss.backward()  # self.scaler.scale(train_loss).backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                # torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.5)
                self.optimizer.step()  # self.scaler.step(self.optimizer)  #
                self.scaler.update()  # leftover from the commented-out AMP (GradScaler) path

                # calc metric
                avg_loss += train_loss.item() / len(train_loader)

                train_true = torch.cat([train_true, y_batch], 0)
                train_preds = torch.cat([train_preds, pred], 0)

            # calc training metric
            train_preds = train_preds.numpy()
            train_preds[np.where(train_preds >= self.threshold)] = 1
            train_preds[np.where(train_preds < self.threshold)] = 0
            metric_train = self.metric.compute(labels=train_true.numpy(),
                                               outputs=train_preds)

            # evaluate the model
            print('Model evaluation...')
            self.model.zero_grad()
            self.model.eval()
            val_preds, val_true = torch.Tensor([]), torch.Tensor([])
            avg_val_loss = 0.0
            with torch.no_grad():
                for X_batch, y_batch in valid_loader:
                    y_batch = y_batch.float().to(self.device)
                    X_batch = X_batch.float().to(self.device)

                    pred = self.model(X_batch)
                    X_batch = X_batch.float().cpu().detach()

                    pred = pred.reshape(-1, pred.shape[-1])
                    y_batch = y_batch.view(-1, y_batch.shape[-1])

                    avg_val_loss += self.loss(
                        pred, y_batch).item() / len(valid_loader)
                    y_batch = y_batch.float().cpu().detach()
                    pred = pred.float().cpu().detach()

                    val_true = torch.cat([val_true, y_batch], 0)
                    val_preds = torch.cat([val_preds, pred], 0)

            # evaluate metric
            val_preds = val_preds.numpy()
            val_preds[np.where(val_preds >= self.threshold)] = 1
            val_preds[np.where(val_preds < self.threshold)] = 0
            metric_val = self.metric.compute(val_true.numpy(), val_preds)

            self.scheduler.step(avg_val_loss)  # note: the scheduler was created with mode='max'
            res = self.early_stopping(score=avg_val_loss, model=self.model)

            # print statistics
            if self.hparams['verbose_train']:
                print(
                    '| Epoch: ',
                    epoch + 1,
                    '| Train_loss: ',
                    avg_loss,
                    '| Val_loss: ',
                    avg_val_loss,
                    '| Metric_train: ',
                    metric_train,
                    '| Metric_val: ',
                    metric_val,
                    '| Current LR: ',
                    self.__get_lr(self.optimizer),
                )

            # # add history to tensorboard
            writer.add_scalars(
                'Loss',
                {
                    'Train_loss': avg_loss,
                    'Val_loss': avg_val_loss
                },
                epoch,
            )

            writer.add_scalars('Metric', {
                'Metric_train': metric_train,
                'Metric_val': metric_val
            }, epoch)

            if res == 2:
                print("Early Stopping")
                print(
                    f'global best min val_loss model score {self.early_stopping.best_score}'
                )
                break
            elif res == 1:
                print(f'save global val_loss model score {avg_val_loss}')

        writer.close()

        self.model.zero_grad()

        return True

    def predict(self, X_test):

        # evaluate the model
        self.model.eval()

        test_loader = torch.utils.data.DataLoader(
            X_test,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate

        test_preds = torch.Tensor([])
        print('Start generation of predictions')
        with torch.no_grad():
            for i, (X_batch, y_batch) in enumerate(tqdm(test_loader)):
                X_batch = X_batch.float().to(self.device)

                pred = self.model(X_batch)

                X_batch = X_batch.float().cpu().detach()

                test_preds = torch.cat([test_preds, pred.cpu().detach()], 0)

        return test_preds.numpy()

    def get_heatmap(self, X_test):

        # evaluate the model
        self.model.eval()

        test_loader = torch.utils.data.DataLoader(
            X_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate

        test_preds = torch.Tensor([])
        with torch.no_grad():
            for i, (X_batch) in enumerate(test_loader):
                X_batch = X_batch.float().to(self.device)

                pred = self.model.activatations(X_batch)
                pred = torch.sigmoid(pred)

                X_batch = X_batch.float().cpu().detach()

                test_preds = torch.cat([test_preds, pred.cpu().detach()], 0)

        return test_preds.numpy()

    def model_save(self, model_path):
        torch.save(self.model, model_path)
        return True

    def model_load(self, model_path):
        self.model = torch.load(model_path)
        return True

    ################## Utils #####################

    def __get_lr(self, optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']
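
The training loop above drives ReduceLROnPlateau and reads the current learning rate back from the optimizer (which is all __get_lr does). A minimal sketch of that bookkeeping on a toy model, stepping the scheduler with a validation score:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=3)

def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']

val_score = 0.5           # e.g. the challenge metric on the validation set
scheduler.step(val_score)
print(get_lr(optimizer))  # 0.001 until `patience` epochs pass without improvement
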
Example no. 13
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)

        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

        self.netD = networks.define_D(opt).to(self.device)
        self.netD = DataParallel(self.netD)
        if self.is_train:
            self.netG.train()
            self.netD.train()

        if not self.is_train and 'attack' in self.opt:
            # G pixel loss
            if opt['pixel_weight'] > 0:
                l_pix_type = opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if opt['feature_weight'] > 0:
                l_fea_type = opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = opt['gan_weight']

        self.delta = 0

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def attack_fgsm(self, is_collect_data=False):
        # collect_data='collect_data' in self.opt['attack'] and self.opt['attack']['collect_data']

        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = False
        self.var_L.requires_grad_()

        self.fake_H = self.netG(self.var_L)

        # l_g_total, l_g_pix, l_g_fea, l_g_gan=self.loss_for_G(self.fake_H,self.var_H,self.var_ref)
        l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)

        # zero_grad
        if self.var_L.grad is not None:
            self.var_L.grad.zero_()
        # self.netG.zero_grad()

        # l_g_total.backward()
        l_g_pix.backward()

        data_grad = self.var_L.grad.data

        sign_data_grad = data_grad.sign()
        perturbed_data = self.var_L + self.opt['attack']['eps'] * sign_data_grad
        perturbed_data = torch.clamp(perturbed_data, 0, 1)

        if is_collect_data:
            init_data = self.var_L.detach()
            self.var_L = perturbed_data.detach()
            perturbed_data = self.var_L.clone().detach()
            return init_data, perturbed_data
        else:
            self.var_L = perturbed_data.detach()
            return

    # TODO test
    def attack_pgd(self, is_collect_data=False):
        eps = self.opt['attack']['eps']

        for p in self.netG.parameters():
            p.requires_grad = False
        orig_input = self.var_L.clone().detach()

        randn = torch.FloatTensor(self.var_L.size()).uniform_(-eps, eps).cuda()
        self.var_L += randn
        self.var_L.clamp_(0, 1.0)

        # self.var_L.requires_grad_()
        # if self.var_L.grad is not None:
        #     self.var_L.grad.zero_()
        self.var_L.detach_()

        for _ in range(self.opt['attack']['step_num']):
            # if self.var_L.grad is not None:
            #     self.var_L.grad.zero_()
            var_L_step = torch.autograd.Variable(self.var_L,
                                                 requires_grad=True)
            self.fake_H = self.netG(var_L_step)
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
            l_pix.backward()
            data_grad = var_L_step.grad.data

            pert = self.opt['attack']['step'] * data_grad.sign()
            self.var_L = self.var_L + pert.data
            self.var_L = torch.max(orig_input - eps, self.var_L)
            self.var_L = torch.min(orig_input + eps, self.var_L)
            self.var_L.clamp_(0, 1.0)

        if is_collect_data:
            return orig_input, self.var_L.clone().detach()
        else:
            self.var_L.detach_()
            return

    def feed_data(self, data, need_GT=True, is_collect_data=False):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

        # TODO attack code start
        if 'attack' in self.opt and need_GT and not self.opt['attack'].get(
                'raw_data', False):
            if self.opt['attack'].get('type') == 'pgd':
                if not is_collect_data:
                    self.attack_pgd()
                else:
                    return self.attack_pgd(is_collect_data=True)
            else:
                if not is_collect_data:
                    self.attack_fgsm()
                else:
                    return self.attack_fgsm(is_collect_data=True)
        # attack code end

    def loss_for_G(self, fake_H, var_H, var_ref):
        l_g_total = 0
        # keep every loss component defined even when its criterion is disabled,
        # so the tuple returned at the end is always valid
        l_g_pix = l_g_fea = l_g_gan = torch.tensor(0.0, device=fake_H.device)
        if self.cri_pix:  # pixel loss
            l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H)
            l_g_total += l_g_pix
        if self.cri_fea:  # feature loss
            real_fea = self.netF(var_H).detach()
            fake_fea = self.netF(fake_H)
            l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
            l_g_total += l_g_fea
        if self.l_gan_w > 0.0:
            if ('train' in self.opt and self.opt['train']['gan_type']
                    == 'gan') or ('attack' in self.opt
                                  and self.opt['gan_type'] == 'gan'):
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif ('train' in self.opt and self.opt['train']['gan_type']
                  == 'ragan') or ('attack' in self.opt
                                  and self.opt['gan_type'] == 'ragan'):
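                # Relativistic average GAN (RaGAN) generator loss (informal):
                #   L_G = [ BCE(D(x_real) - mean(D(x_fake)), 0)
                #         + BCE(D(x_fake) - mean(D(x_real)), 1) ] / 2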
                pred_d_real = self.netD(var_ref).detach()
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
        else:
            l_g_gan = torch.tensor(0.0)
        return l_g_total, l_g_pix, l_g_fea, l_g_gan

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = True
        if 'adv_train' in self.opt:
            self.var_L.requires_grad_()
            if self.var_L.grad is not None:
                self.var_L.grad.data.zero_()

        if 'adv_train' not in self.opt:
            self.fake_H = self.netG(self.var_L)
        else:
            self.fake_H = self.netG(torch.clamp(self.var_L + self.delta, 0, 1))

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if 'adv_train' not in self.opt:
                l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                    self.fake_H, self.var_H, self.var_ref)

                self.optimizer_G.zero_grad()

                l_g_total.backward()
                self.optimizer_G.step()
            else:
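                # "Free"-style adversarial training replay (informal sketch): the
                # same mini-batch is replayed m times; after each generator update
                # the persistent perturbation delta is pushed one signed step along
                # the input gradient and clipped back into the eps-ball.
                # self.delta is assumed to be initialised elsewhere (e.g. to zeros).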
                for _ in range(self.opt['adv_train']['m']):
                    l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                        self.fake_H, self.var_H, self.var_ref)

                    self.optimizer_G.zero_grad()
                    if self.var_L.grad is not None:
                        self.var_L.grad.data.zero_()

                    l_g_total.backward()
                    self.optimizer_G.step()

                    self.delta = self.delta + \
                        self.opt['adv_train']['step'] * \
                        self.var_L.grad.data.sign()
                    self.delta.clamp_(-self.opt['attack']['eps'],
                                      self.opt['attack']['eps'])
                    self.fake_H = self.netG(
                        torch.clamp(self.var_L + self.delta, 0, 1))
        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            # detach to avoid BP to G
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
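            # RaGAN discriminator loss, computed in two separate backward passes;
            # the opposite branch's prediction is detached each time so only the
            # current term receives gradients, and fake_H is detached so nothing
            # flows back into G.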
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.log_dict['l_g_total'] = l_g_total.item()
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            logger.info(
                'Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                logger.info(
                    'Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_F = self.opt['path']['pretrain_model_F']
        if load_path_F is not None:
            logger.info('Loading model for F [{:s}] ...'.format(load_path_F))
            network = self.netF.module.features
            if isinstance(network, nn.DataParallel):
                network = network.module
            load_net = torch.load(load_path_F)
            load_net_clean = OrderedDict()  # strip the 'module.features.' prefix
            for k, v in load_net.items():
                if k.startswith('module.features.'):
                    load_net_clean[k[16:]] = v
            network.load_state_dict(load_net_clean,
                                    strict=self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
Example n. 14
0
class CLSGAN_Model(BaseModel):
    def __init__(self, opt):
        super(CLSGAN_Model, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        G_opt = opt['network_G']

        # define networks and load pretrained models
        self.netG = RCAN(G_opt).to(self.device)
        self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = Discriminator_VGG_256(3, G_opt['nf']).to(self.device)
            self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = VGGFeatureExtractor(feature_layer=34,
                                                use_bn=False,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                self.netF = DataParallel(self.netF)

            # G classification loss
            if train_opt['cls_weight'] > 0:
                l_cls_type = train_opt['cls_criterion']
                if l_cls_type == 'CE':
                    self.cri_cls = nn.NLLLoss().to(self.device)
                elif l_cls_type == 'l1':
                    self.cri_cls = nn.L1Loss().to(self.device)
                elif l_cls_type == 'l2':
                    self.cri_cls = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_cls_type))
                self.l_cls_w = train_opt['cls_weight']
            else:
                logger.info('Remove classification loss.')
                self.cri_cls = None
            if self.cri_cls:  # load the VGG-based classification prior network
                self.netC = VGGFeatureExtractor(feature_layer=49,
                                                use_bn=True,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classification prior."
                self.netC.load_model(load_path_C)
                self.netC = DataParallel(self.netC)

            if train_opt['brc_weight'] > 0:
                self.l_brc_w = train_opt['brc_weight']
                self.netR = VGG_Classifier().to(self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classification prior."
                self.netR.load_model(load_path_C)
                self.netR = DataParallel(self.netR)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H, self.cls_L = self.netG(self.var_L)
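        # this generator variant returns the SR image together with per-image
        # class logits (self.cls_L) that are supervised by the branch loss below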

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            if self.cri_cls:  # F-G classification loss
                #print(self.netC(self.var_H).detach().shape)
                #real_cls = self.netC(self.var_H).argmax(1).detach()
                #fake_cls = torch.log( nn.Softmax(dim=1) (self.netC(self.fake_H)) )
                real_cls = self.netC(self.var_H).detach()
                fake_cls = self.netC(self.fake_H)
                l_g_cls = self.l_cls_w * self.cri_cls(fake_cls, real_cls)
                l_g_total += l_g_cls  # accumulate with the pixel/feature terms
            if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if self.opt['train']['br_optimizer'] == 'joint':
                ref = self.netR(self.var_H).argmax(dim=1)
                l_branch = self.l_brc_w * nn.CrossEntropyLoss()(self.cls_L,
                                                                ref)
                # joint update: fold the branch loss into the generator objective
                # so it is backpropagated together with l_g_total below
                l_g_total += l_branch

            l_g_total.backward()
            self.optimizer_G.step()

            self.optimizer_G.zero_grad()

            # separate branch update
            if self.opt['train']['br_optimizer'] == 'branch':
                ref = self.netR(self.var_H).argmax(dim=1)
                l_branch = self.l_brc_w * nn.CrossEntropyLoss()(self.cls_L,
                                                                ref)
                self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            # pred_d_real = self.netD(self.var_ref)
            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
            # l_d_total = (l_d_real + l_d_fake) / 2
            # l_d_total.backward()
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, _ = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

            if self.cri_cls:  # C, F-G Classification Network
                s, n = self.get_network_description(self.netC)
                if isinstance(self.netC, nn.DataParallel) or isinstance(
                        self.netC, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netC.__class__.__name__,
                        self.netC.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netC.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network C structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)

    def clear_data(self):
        return None
Example n. 15
0
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
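        # per-class weights for the 19 Cityscapes classes (index 19 is the
        # ignore/background entry, weighted 0); the values appear to follow the
        # usual inverse-log-frequency class balancing for this dataset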
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0

    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GreenhouseDepth, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseDepth(root=args.data_path,
                                        list_name='train_depth_ae.txt',
                                        train=True,
                                        size=crop_size,
                                        scale=args.scale,
                                        use_filter=True)
        val_dataset = GreenhouseRGBDSegmentation(root=args.data_path,
                                                 list_name='val_depth_ae.txt',
                                                 train=False,
                                                 size=crop_size,
                                                 scale=args.scale,
                                                 use_depth=True)
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        seg_classes = len(GREENHOUSE_CLASS_LIST)
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    from model.autoencoder.depth_autoencoder import espnetv2_autoenc
    args.classes = 3
    model = espnetv2_autoenc(args)

    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr * args.lr_mult
    }]

    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 1, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0

    print('device : ' + device)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    #criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
    #                             device=device, ignore_idx=args.ignore_idx,
    #                             class_wts=class_wts.to(device))
    criterion = nn.MSELoss()
    # criterion = nn.L1Loss()
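    # The autoencoder is trained to reconstruct the RGB image from the depth
    # input (see the training loop below), so a plain MSE reconstruction loss
    # is used instead of the commented-out segmentation losses.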

    if num_gpus >= 1:
        if num_gpus == 1:
            # with a single GPU the criterion does not need the custom
            # DataParallelCriteria wrapper, so the standard torch DataParallel is used
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    best_loss = float('inf')  # lower validation loss is better
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_seg
        # optimizer.param_groups[1]['lr'] = lr_seg

        # Train
        model.train()
        losses = AverageMeter()
        for i, batch in enumerate(train_loader):
            inputs = batch[1].to(device=device)  # Depth
            target = batch[0].to(device=device)  # RGB

            outputs = model(inputs)

            if device == 'cuda':
                loss = criterion(outputs, target).mean()
                if isinstance(outputs, (list, tuple)):
                    target_dev = outputs[0].device
                    outputs = gather(outputs, target_device=target_dev)
            else:
                loss = criterion(outputs, target)

            losses.update(loss.item(), inputs.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #             if not (i % 10):
            #                 print("Step {}, write images".format(i))
            #                 image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
            #                 writer.add_image('Autoencoder/results/train', image_grid, len(train_loader) * epoch + i)

            writer.add_scalar('Autoencoder/Loss/train', loss.item(),
                              len(train_loader) * epoch + i)

            print_info_message('Running batch {}/{} of epoch {}'.format(
                i + 1, len(train_loader), epoch + 1))

        train_loss = losses.avg

        writer.add_scalar('Autoencoder/LR/seg', round(lr_seg, 6), epoch)

        # Val
        if epoch % 5 == 0:
            losses = AverageMeter()
            with torch.no_grad():
                for i, batch in enumerate(val_loader):
                    inputs = batch[2].to(device=device)  # Depth
                    target = batch[0].to(device=device)  # RGB

                    outputs = model(inputs)

                    if device == 'cuda':
                        loss = criterion(outputs, target)  # .mean()
                        if isinstance(outputs, (list, tuple)):
                            target_dev = outputs[0].device
                            outputs = gather(outputs, target_device=target_dev)
                    else:
                        loss = criterion(outputs, target)

                    losses.update(loss.item(), inputs.size(0))

                    image_grid = torchvision.utils.make_grid(
                        outputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/results/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        inputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/inputs/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        target.data.cpu()).numpy()
                    writer.add_image('Autoencoder/target/val', image_grid,
                                     epoch)

            val_loss = losses.avg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # remember the best (lowest) validation loss and save checkpoint
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)

            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Autoencoder/Loss/val', val_loss, epoch)

    writer.close()
Example n. 16
0
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.opt = opt

        self.segmentor = None

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D = networks.define_video_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DistributedDataParallel(
                        self.net_video_D,
                        device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DataParallel(self.net_video_D)

            self.netG.train()
            self.netD.train()
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Pixel mask loss
            if train_opt.get("pixel_mask_weight", 0) > 0:
                l_pix_type = train_opt['pixel_mask_criterion']
                self.cri_pix_mask = LMaskLoss(
                    l_pix_type=l_pix_type,
                    segm_mask=train_opt['segm_mask']).to(self.device)
                self.l_pix_mask_w = train_opt['pixel_mask_weight']
            else:
                logger.info('Remove pixel mask loss.')
                self.cri_pix_mask = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # Video gan weight
            if train_opt.get("gan_video_weight", 0) > 0:
                self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0,
                                             0.0).to(self.device)
                self.l_gan_video_w = train_opt['gan_video_weight']

                # optical flow between frames i and i+1 alone is not enough:
                # computing the flow for frame i+1 also needs the LR frame i+2
                if 'train' in self.opt['datasets'].keys():
                    key = "train"
                else:
                    key = 'test_1'
                assert self.opt['datasets'][key][
                    'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}"
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # Video D
            if train_opt.get("gan_video_weight", 0) > 0:
                self.optimizer_video_D = torch.optim.Adam(
                    self.net_video_D.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_video_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.img_path = data['GT_path']
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
        if self.train_opt.get("use_HR_ref"):
            self.var_HR_ref = data['img_reference'].to(self.device)
        if "LQ_next" in data.keys():
            self.var_L_next = data['LQ_next'].to(self.device)
            if "GT_next" in data.keys():
                self.var_H_next = data['GT_next'].to(self.device)
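                # stack the current and next GT frames along a new temporal
                # axis (dim=2) to form the clip fed to the video discriminator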
                self.var_video_H = torch.cat(
                    [data['GT'].unsqueeze(2), data['GT_next'].unsqueeze(2)],
                    dim=2).to(self.device)
        else:
            self.var_L_next = None

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        args = [self.var_L]
        if self.train_opt.get('use_HR_ref'):
            args += [self.var_HR_ref]
        if self.var_L_next is not None:
            args += [self.var_L_next]
        self.fake_H, self.binary_mask = self.netG(*args)
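        # this generator returns the SR frame plus an optional soft mask; when
        # present, its mean appears to act as a sparsity regulariser on the loss
        # below (the "OFLOW regular" term)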

        #Video Gan
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            with torch.no_grad():
                args = [self.var_L, self.var_HR_ref, self.var_L_next]
                self.fake_H_next, self.binary_mask_next = self.netG(*args)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_pix_mask:
                l_g_pix_mask = self.l_pix_mask_w * self.cri_pix_mask(
                    self.fake_H, self.var_H, self.var_HR_ref)
                l_g_total += l_g_pix_mask
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            # Image Gan
            if self.opt['network_D'] == "discriminator_vgg_128_mask":
                import torch.nn.functional as F
                from models.modules import psina_seg
                if self.segmentor is None:
                    self.segmentor = psina_seg.base.SegmentationModule(
                        encode='stationary_probs').to(self.device)
                self.segmentor = self.segmentor.eval()
                lr = F.interpolate(self.var_H,
                                   scale_factor=0.25,
                                   mode='nearest')
                with torch.no_grad():
                    binary_mask = (
                        1 - self.segmentor.predict(lr[:, [2, 1, 0], ::]))
                binary_mask = F.interpolate(binary_mask,
                                            scale_factor=4,
                                            mode='nearest')
                pred_g_fake = self.netD(self.fake_H,
                                        self.fake_H * (1 - binary_mask),
                                        self.var_HR_ref,
                                        binary_mask * self.var_HR_ref)
            else:
                pred_g_fake = self.netD(self.fake_H)

            if self.opt['train']['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                if self.opt['network_D'] == "discriminator_vgg_128_mask":
                    pred_g_fake = self.netD(self.var_H,
                                            self.var_H * (1 - binary_mask),
                                            self.var_HR_ref,
                                            binary_mask * self.var_HR_ref)
                else:
                    pred_d_real = self.netD(self.var_H)
                pred_d_real = pred_d_real.detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            #Video Gan
            if self.opt['train'].get("gan_video_weight", 0) > 0:
                self.fake_video_H = torch.cat(
                    [self.fake_H.unsqueeze(2),
                     self.fake_H_next.unsqueeze(2)],
                    dim=2)
                pred_g_video_fake = self.net_video_D(self.fake_video_H)
                if self.opt['train']['gan_video_type'] == 'gan':
                    l_g_video_gan = self.l_gan_video_w * self.cri_video_gan(
                        pred_g_video_fake, True)
                elif self.opt['train']['gan_video_type'] == 'ragan':
                    pred_d_video_real = self.net_video_D(self.var_video_H)
                    pred_d_video_real = pred_d_video_real.detach()
                    l_g_video_gan = self.l_gan_video_w * (self.cri_video_gan(
                        pred_d_video_real - torch.mean(pred_g_video_fake),
                        False) + self.cri_video_gan(
                            pred_g_video_fake - torch.mean(pred_d_video_real),
                            True)) / 2
                l_g_total += l_g_video_gan

            # OFLOW regular
            if self.binary_mask is not None:
                l_g_total += 1 * self.binary_mask.mean()

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            for p in self.net_video_D.parameters():
                p.requires_grad = True

        # optimize Image D
        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_H)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
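            # relativistic average discriminator loss: real outputs should score
            # above the mean fake output, fake outputs below the mean real output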
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2
        l_d_total.backward()
        self.optimizer_D.step()

        # optimize Video D
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.optimizer_video_D.zero_grad()
            l_d_video_total = 0
            pred_d_video_real = self.net_video_D(self.var_video_H)
            pred_d_video_fake = self.net_video_D(
                self.fake_video_H.detach())  # detach to avoid BP to G
            if self.opt['train']['gan_video_type'] == 'gan':
                l_d_video_real = self.cri_video_gan(pred_d_video_real, True)
                l_d_video_fake = self.cri_video_gan(pred_d_video_fake, False)
                l_d_video_total = l_d_video_real + l_d_video_fake
            elif self.opt['train']['gan_video_type'] == 'ragan':
                l_d_video_real = self.cri_video_gan(
                    pred_d_video_real - torch.mean(pred_d_video_fake), True)
                l_d_video_fake = self.cri_video_gan(
                    pred_d_video_fake - torch.mean(pred_d_video_real), False)
                l_d_video_total = (l_d_video_real + l_d_video_fake) / 2
            l_d_video_total.backward()
            self.optimizer_video_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.log_dict['D_video_real'] = torch.mean(
                pred_d_video_real.detach())
            self.log_dict['D_video_fake'] = torch.mean(
                pred_d_video_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            args = [self.var_L]
            if self.train_opt.get('use_HR_ref'):
                args += [self.var_HR_ref]
            if self.var_L_next is not None:
                args += [self.var_L_next]
            self.fake_H, self.binary_mask = self.netG(*args)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if self.binary_mask is not None:
            out_dict['binary_mask'] = self.binary_mask.detach()[0].float().cpu(
            )
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        # G
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['pretrain_model_G_strict_load'])

        if self.opt['network_G'].get("pretrained_net") is not None:
            self.netG.module.load_pretrained_net_weights(
                self.opt['network_G']['pretrained_net'])

        # D
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['pretrain_model_D_strict_load'])

        # Video D
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            load_path_video_D = self.opt['path'].get("pretrain_model_video_D")
            if self.opt['is_train'] and load_path_video_D is not None:
                self.load_network(
                    load_path_video_D, self.net_video_D,
                    self.opt['path']['pretrain_model_video_D_strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.save_network(self.net_video_D, 'video_D', iter_step)

    @staticmethod
    def _freeze_net(network):
        for p in network.parameters():
            p.requires_grad = False
        return network

    @staticmethod
    def _unfreeze_net(network):
        for p in network.parameters():
            p.requires_grad = True
        return network

    def freeze(self, G, D):
        if G:
            self.netG.module.net = self._freeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._freeze_net(self.netD.module)

    def unfreeze(self, G, D):
        if G:
            self.netG.module.net = self._unfreeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._unfreeze_net(self.netD.module)
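
A hedged side note on the 'ragan' branches above: both the image and the video discriminator use the relativistic average GAN formulation. A minimal standalone sketch of that loss, assuming a cri_gan(pred, target_is_real) criterion like the GANLoss used in these snippets (function names here are illustrative, not from the original code):

import torch

def ragan_d_loss(pred_real, pred_fake, cri_gan):
    # Relativistic average discriminator loss: real logits should sit above the
    # mean fake logit, fake logits below the mean real logit.
    l_real = cri_gan(pred_real - torch.mean(pred_fake), True)
    l_fake = cri_gan(pred_fake - torch.mean(pred_real), False)
    return (l_real + l_fake) / 2

def ragan_g_loss(pred_real_detached, pred_fake, cri_gan):
    # Generator side: the same comparisons with the labels swapped, so the
    # generator pushes fake logits above the (detached) real average.
    return (cri_gan(pred_real_detached - torch.mean(pred_fake), False) +
            cri_gan(pred_fake - torch.mean(pred_real_detached), True)) / 2
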
Example n. 17
class Trainer(object):
    def __init__(self,
                 batch=8,
                 subdivisions=4,
                 epochs=100,
                 burn_in=1000,
                 steps=[400000, 450000]):

        _model = build_from_dict(model, DETECTORS)
        self.model = DataParallel(_model.cuda(), device_ids=[0])

        self.train_dataset = build_from_dict(data_cfg['train'], DATASET)
        self.val_dataset = build_from_dict(data_cfg['val'], DATASET)

        self.burn_in = burn_in
        self.steps = steps
        self.epochs = epochs

        self.batch = batch
        self.subdivisions = subdivisions

        self.train_size = len(self.train_dataset)
        self.val_size = len(self.val_dataset)

        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=batch // subdivisions,
                                       shuffle=True,
                                       num_workers=1,
                                       pin_memory=True,
                                       drop_last=True,
                                       collate_fn=self.collate)

        self.val_loader = DataLoader(self.val_dataset,
                                     batch_size=batch // subdivisions,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True,
                                     drop_last=True,
                                     collate_fn=self.collate)

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=0.001 / batch,
            betas=(0.9, 0.999),
            eps=1e-08,
        )

        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer,
                                                     self.burnin_schedule)

    def train(self):
        self.model.train()
        global_step = 0
        checkpoints = r'/disk2/project/pytorch-YOLOv4/checkpoints/'
        save_prefix = 'Yolov4_epoch_'
        saved_models = collections.deque()
        for epoch in range(self.epochs):

            epoch_loss = 0
            epoch_step = 0

            for i, batch in enumerate(self.train_loader):
                losses = self.model(**batch)
                loss = self.parse_losses(losses)
                loss.backward()
                epoch_loss += loss.item()
                print('loss :{}'.format(loss))

                global_step += 1
                epoch_step += 1

                if global_step % self.subdivisions == 0:
                    self.optimizer.step()       # apply gradients accumulated over `subdivisions` sub-batches
                    self.scheduler.step()
                    self.optimizer.zero_grad()  # reset before the next accumulation window

            try:
                # os.mkdir(config.checkpoints)
                os.makedirs(checkpoints, exist_ok=True)
            except OSError:
                pass
            save_path = os.path.join(checkpoints,
                                     f'{save_prefix}{epoch + 1}.pth')
            # Save the underlying module's weights; self.model is wrapped in DataParallel.
            torch.save(self.model.module.state_dict(), save_path)

            saved_models.append(save_path)
            if len(saved_models) > 5:
                model_to_remove = saved_models.popleft()
                try:
                    os.remove(model_to_remove)
                except OSError:
                    pass

    def burnin_schedule(self, i):
        if i < self.burn_in:
            factor = pow(i / self.burn_in, 4)
        elif i < self.steps[0]:
            factor = 1.0
        elif i < self.steps[1]:
            factor = 0.1
        else:
            factor = 0.01
        return factor

    def collate(self, batch):
        if 'multi_scale' in data_cfg.keys() and len(
                data_cfg['multi_scale']) > 0:
            multi_scale = data_cfg['multi_scale']
            if isinstance(multi_scale, dict) and 'type' in multi_scale.keys():
                randomShape = build_from_dict(multi_scale, TRANSFORMS)
                batch = randomShape(batch)
        collate = default_collate(batch)
        return collate

    def parse_losses(self, losses):
        log_vars = collections.OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    '{} is not a tensor or list of tensors'.format(loss_name))

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        return loss
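
The subdivisions counter above implements gradient accumulation: backward() runs on every sub-batch, and the optimizer only steps once a full batch worth of gradients has accumulated. A minimal standalone sketch of that pattern together with a LambdaLR burn-in schedule (the model, data, and the 100-step burn-in below are made up for illustration):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Quartic warm-up for the first 100 updates, then a constant factor of 1.0.
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, lambda i: (i / 100) ** 4 if i < 100 else 1.0)

subdivisions = 4
for step in range(1, 401):
    x, y = torch.randn(2, 10), torch.randn(2, 1)
    loss = nn.functional.mse_loss(model(x), y) / subdivisions  # average over sub-batches
    loss.backward()                       # gradients accumulate across sub-batches
    if step % subdivisions == 0:
        optimizer.step()                  # apply the accumulated gradient
        scheduler.step()                  # advance the LR schedule per update
        optimizer.zero_grad()             # clear before the next accumulation window
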
Example n. 18
class SRFlowModel(BaseModel):
    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt)
        self.opt = opt

        self.heats = opt['val']['heats']
        self.n_sample = opt['val']['n_sample']
        self.hr_size = opt_get(opt,
                               ['datasets', 'train', 'center_crop_hr_size'])
        self.hr_size = 160 if self.hr_size is None else self.hr_size
        self.lr_size = self.hr_size // opt['scale']

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_Flow(opt, step).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()

        if opt_get(opt, ['path', 'resume_state'], 1) is not None:
            self.load()
        else:
            print(
                "WARNING: skipping initial loading, due to resume_state None")

        if self.is_train:
            self.netG.train()

            self.init_optimizer_and_scheduler(train_opt)
            self.log_dict = OrderedDict()

    def to(self, device):
        self.device = device
        self.netG.to(device)

    def init_optimizer_and_scheduler(self, train_opt):
        # optimizers
        self.optimizers = []
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params_RRDB = []
        optim_params_other = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            print(k, v.requires_grad)
            if v.requires_grad:
                if '.RRDB.' in k:
                    optim_params_RRDB.append(v)
                    print('opt', k)
                else:
                    optim_params_other.append(v)
            elif self.rank <= 0:
                # Warn only for parameters that are excluded from optimization.
                logger.warning('Params [{:s}] will not optimize.'.format(k))

        print('rrdb params', len(optim_params_RRDB))

        # Adam expects the betas as a single tuple; separate 'beta1'/'beta2' keys
        # would be silently ignored by the optimizer.
        self.optimizer_G = torch.optim.Adam(
            [{
                "params": optim_params_other,
                "lr": train_opt['lr_G'],
                "betas": (train_opt['beta1'], train_opt['beta2']),
                "weight_decay": wd_G
            }, {
                "params": optim_params_RRDB,
                "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
                "betas": (train_opt['beta1'], train_opt['beta2']),
                "weight_decay": wd_G
            }])

        self.optimizers.append(self.optimizer_G)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        train_opt['lr_steps'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state'],
                        lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')

    def add_optimizer_and_scheduler_RRDB(self, train_opt):
        # optimizers
        assert len(self.optimizers) == 1, self.optimizers
        assert len(self.optimizer_G.param_groups[1]
                   ['params']) == 0, self.optimizer_G.param_groups[1]
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                if '.RRDB.' in k:
                    self.optimizer_G.param_groups[1]['params'].append(v)
        assert len(self.optimizer_G.param_groups[1]['params']) > 0

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):

        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
                and not self.netG.module.RRDB_training:
            if self.netG.module.set_rrdb_training(True):
                self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        # self.print_rrdb_state()

        self.netG.train()
        self.log_dict = OrderedDict()
        self.optimizer_G.zero_grad()

        losses = {}
        weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
        weight_fl = 1 if weight_fl is None else weight_fl
        if weight_fl > 0:
            #print('self.var_L: ', self.var_L, self.var_L.shape)
            #print('self.real_H: ', self.real_H, self.real_H.shape)
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False)
            nll_loss = torch.mean(nll)
            losses['nll_loss'] = nll_loss * weight_fl
            #print('nll_loss: ', nll_loss)

        weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
        if weight_l1 > 0:
            z = self.get_z(heat=0,
                           seed=None,
                           batch_size=self.var_L.shape[0],
                           lr_shape=self.var_L.shape)
            sr, logdet = self.netG(lr=self.var_L,
                                   z=z,
                                   eps_std=0,
                                   reverse=True,
                                   reverse_with_grad=True)
            l1_loss = (sr - self.real_H).abs().mean()
            losses['l1_loss'] = l1_loss * weight_l1
            #print('l1_loss: ', l1_loss)

        total_loss = sum(losses.values())
        #print('total_loss: ', total_loss)
        # total_loss:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
        # ERROR: RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 11)

        total_loss.backward()
        self.optimizer_G.step()

        mean = total_loss.item()
        return mean

    def print_rrdb_state(self):
        for name, param in self.netG.module.named_parameters():
            if "RRDB.conv_first.weight" in name:
                print(name, param.requires_grad, param.data.abs().sum())
        print('params',
              [len(p['params']) for p in self.optimizer_G.param_groups])

    def test(self):
        self.netG.eval()
        self.fake_H = {}
        for heat in self.heats:
            for i in range(self.n_sample):
                z = self.get_z(heat,
                               seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                with torch.no_grad():
                    self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L,
                                                               z=z,
                                                               eps_std=heat,
                                                               reverse=True)
        with torch.no_grad():
            _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_encode_nll(self, lq, gt):
        self.netG.eval()
        with torch.no_grad():
            _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
        return self.get_sr_with_z(lq, heat, seed, z, epses)[0]

    def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, _, _ = self.netG(gt=gt,
                                lr=lq,
                                reverse=False,
                                epses=epses,
                                add_gt_noise=add_gt_noise)
        self.netG.train()
        return z

    def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, nll, _ = self.netG(gt=gt,
                                  lr=lq,
                                  reverse=False,
                                  epses=epses,
                                  add_gt_noise=add_gt_noise)
        self.netG.train()
        return z, nll

    def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
        self.netG.eval()

        z = self.get_z(heat, seed, batch_size=lq.shape[0],
                       lr_shape=lq.shape) if z is None and epses is None else z

        with torch.no_grad():
            sr, logdet = self.netG(lr=lq,
                                   z=z,
                                   eps_std=heat,
                                   reverse=True,
                                   epses=epses)
        self.netG.train()
        return sr, z

    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
        if seed is not None:
            torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)
            z = torch.normal(mean=0, std=heat,
                             size=(batch_size, C, H,
                                   W)) if heat > 0 else torch.zeros(
                                       (batch_size, C, H, W))
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2**(L - 3)
            z_size = int(self.lr_size // fac)
            z = torch.normal(mean=0,
                             std=heat,
                             size=(batch_size, 3 * 8 * 8 * fac * fac, z_size,
                                   z_size))
        return z

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        for heat in self.heats:
            for i in range(self.n_sample):
                out_dict[('SR', heat,
                          i)] = self.fake_H[(heat,
                                             i)].detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        _, get_resume_model_path = get_resume_paths(self.opt)
        if get_resume_model_path is not None:
            self.load_network(get_resume_model_path,
                              self.netG,
                              strict=True,
                              submodule=None)
            return

        load_path_G = self.opt['path']['pretrain_model_G']
        load_submodule = self.opt['path'].get('load_submodule', 'RRDB')
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G,
                              self.netG,
                              self.opt['path'].get('strict_load', True),
                              submodule=load_submodule)

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
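
For context, sampling from this flow model at test time just means drawing a Gaussian latent z whose standard deviation equals the chosen "heat" and running the network in reverse, as get_z/get_sr_with_z do above. A hedged usage sketch, assuming model is a constructed SRFlowModel and lq/gt are matching LR/HR batches (none of these names come from the snippet itself):

# Hypothetical usage; model, lq and gt are assumed to exist.
for heat in (0.0, 0.5, 0.9):
    sr, z = model.get_sr_with_z(lq, heat=heat, seed=1)  # z ~ N(0, heat^2), reverse pass
    # heat=0 gives the deterministic zero-latent prediction; larger heat adds diversity.
nll = model.get_encode_nll(lq, gt)  # forward pass returns the negative log-likelihood
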
Example n. 19
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = DataParallel(Net())
model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss().cuda()

model.train()
for batch_idx, (data, target) in enumerate(train_loader):
    # torch.autograd.Variable is deprecated; tensors carry autograd state directly.
    input_var = data.cuda()
    target_var = target.cuda()

    print('Getting model output')
    output = model(input_var)
    print('Got model output')

    loss = criterion(output, target_var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('Finished')
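
A small standalone sketch (separate from the loop above, purely illustrative) of what DataParallel does with the batch: the input is split along dimension 0, each GPU replica runs forward on its own slice, and the outputs are gathered back onto the default device.

import torch
import torch.nn as nn

class EchoShape(nn.Module):
    def forward(self, x):
        # Each replica prints the shape of the slice it received.
        print('replica input:', x.shape)
        return x

if torch.cuda.is_available():
    m = nn.DataParallel(EchoShape()).cuda()
    out = m(torch.randn(8, 3, 32, 32).cuda())  # e.g. two GPUs each see a batch of 4
    print('gathered output:', out.shape)        # back to the full batch of 8
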
Example n. 20
class VideoSRBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoSRBaseModel, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss(reduction="sum").to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss(reduction="sum").to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type))
            self.cri_aligned = (nn.L1Loss(reduction="sum").to(self.device)
                                if train_opt["aligned_criterion"] else None)
            self.l_pix_w = train_opt["pixel_weight"]

            #### optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            if train_opt["ft_tsa_only"]:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if "tsa_fusion" in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))
                optim_params = [
                    {  # add normal params first
                        "params": normal_params,
                        "lr": train_opt["lr_G"],
                    },
                    {"params": tsa_fusion_params, "lr": train_opt["lr_G"]},
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))

            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data["LQs"].to(self.device)
        if need_GT:
            self.real_H = data["GT"].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]["lr"] = 0

    def optimize_parameters(self, step):
        if self.opt["train"][
                "ft_tsa_only"] and step < self.opt["train"]["ft_tsa_only"]:
            self.set_params_lr_zero()

        train_opt = self.opt["train"]
        opt_net = self.opt["network_G"]

        self.optimizer_G.zero_grad()
        self.fake_H, aligned_fea = self.netG(self.var_L)

        l_total = 0

        # Pixel loss
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_total += l_pix

        # Aligned loss
        B, N, C, H, W = self.var_L.size()  # N video frames
        center = N // 2
        nf = opt_net["nf"]
        fea2imgConv = nn.Conv2d(nf, 3, 3, 1, 1)
        fea2imgConv.eval()
        # The aligned features are on the GPU, so this projection conv must be moved
        # there too (input and weights must share the same type/device).
        fea2imgConv.cuda()

        # Stack N of center LR images
        var_L_center = self.var_L[:, center, :, :, :].contiguous()
        var_L_center_expanded = var_L_center.expand(1, -1, -1, -1, -1)
        var_L_center_repeated = var_L_center_expanded.repeat(N, 1, 1, 1, 1)
        var_L_stacked_center = torch.transpose(var_L_center_repeated, 0, 1)

        # Assign center frame to center aligned feature
        with torch.no_grad():
            aligned_img = fea2imgConv(aligned_fea.view(-1, nf, H, W)).view(
                B, N, -1, H, W)
        aligned_img[:, center, :, :, :] = var_L_center

        l_aligned = (1 / (N - 1) *
                     self.cri_aligned(aligned_img, var_L_stacked_center)
                     if train_opt["aligned_criterion"] else 0)
        l_total += l_aligned

        l_total.backward()

        self.optimizer_G.step()

        # set log
        self.log_dict["l_pix"] = l_pix.item()
        if train_opt["aligned_criterion"]:
            self.log_dict["l_aligned"] = l_aligned.item()
            self.log_dict["l_total"] = l_total.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, _ = self.netG(self.var_L)  # netG returns (output, aligned features), see optimize_parameters
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["restore"] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt["path"]["strict_load"])

    def save(self, iter_label):
        self.save_network(self.netG, "G", iter_label)
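
The expand/repeat/transpose sequence in optimize_parameters above only broadcasts the single centre LR frame to the full [B, N, C, H, W] shape so it can be compared against every aligned frame. A shape-only sketch with made-up sizes:

import torch

B, N, C, H, W = 2, 5, 3, 16, 16                      # illustrative sizes
var_L = torch.randn(B, N, C, H, W)
center = N // 2

var_L_center = var_L[:, center]                      # [B, C, H, W]
stacked = (var_L_center.expand(1, -1, -1, -1, -1)    # [1, B, C, H, W]
           .repeat(N, 1, 1, 1, 1)                    # [N, B, C, H, W]
           .transpose(0, 1))                         # [B, N, C, H, W]
assert stacked.shape == (B, N, C, H, W)
# An equivalent, arguably clearer form:
# var_L_center.unsqueeze(1).expand(-1, N, -1, -1, -1)
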
Example n. 21
class ModelStage(ModelBase):
    """Train with pixel loss"""
    def __init__(self, opt, stage0=False, stage1=False, stage2=False):
        super(ModelStage, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.stage0 = stage0
        self.stage1 = stage1
        self.stage2 = stage2
        self.netG = define_G(opt, self.stage0, self.stage1,
                             self.stage2).to(self.device)
        self.netG = DataParallel(self.netG)

    """
    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------
    """

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        self.opt_train = self.opt['train']  # training option
        self.load()  # load model
        self.netG.train()  # set training mode,for BN
        self.define_loss()  # define loss
        self.define_optimizer()  # define optimizer
        self.define_scheduler()  # define scheduler
        self.log_dict = OrderedDict()  # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):

        if self.stage0:
            load_path_G = self.opt['path']['pretrained_netG0']
        elif self.stage1:
            load_path_G = self.opt['path']['pretrained_netG1']
        elif self.stage2:
            load_path_G = self.opt['path']['pretrained_netG2']
        else:
            load_path_G = None  # no stage selected, nothing to load
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        if self.stage0:
            self.save_network(self.save_dir, self.netG, 'G0', iter_label)
        elif self.stage1:
            self.save_network(self.save_dir, self.netG, 'G1', iter_label)
        elif self.stage2:
            self.save_network(self.save_dir, self.netG, 'G2', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError(
                'Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params,
                                lr=self.opt_train['G_optimizer_lr'],
                                weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        self.schedulers.append(
            lr_scheduler.MultiStepLR(self.G_optimizer,
                                     self.opt_train['G_scheduler_milestones'],
                                     self.opt_train['G_scheduler_gamma']))

    """
    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------
    """

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data):
        if self.stage0:
            Ls = data['ls']
            self.Ls = util.tos(*Ls, device=self.device)
            Hs = data['hs']
            self.Hs = util.tos(*Hs, device=self.device)
        if self.stage1:
            self.L0 = data['L0'].to(self.device)
            self.H = data['H'].to(self.device)
        elif self.stage2:
            Ls = data['L']
            self.Ls = util.tos(*Ls, device=self.device)
            self.H = data['H'].to(self.device)  #hide for test

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):

        self.G_optimizer.zero_grad()

        if self.stage0:
            self.Es = self.netG(self.Ls)
            _loss = []
            for (Es_i, Hs_i) in zip(self.Es, self.Hs):
                _loss += [self.G_lossfn(Es_i, Hs_i)]
            G_loss = sum(_loss) * self.G_lossfn_weight

        if self.stage1:
            self.E = self.netG(self.L0)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)

        if self.stage2:
            self.E = self.netG(self.Ls)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)

        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        G_optimizer_clipgrad = self.opt_train[
            'G_optimizer_clipgrad'] if self.opt_train[
                'G_optimizer_clipgrad'] else 0
        if G_optimizer_clipgrad > 0:
            torch.nn.utils.clip_grad_norm_(
                self.netG.parameters(),
                max_norm=self.opt_train['G_optimizer_clipgrad'],
                norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        G_regularizer_orthstep = self.opt_train[
            'G_regularizer_orthstep'] if self.opt_train[
                'G_regularizer_orthstep'] else 0
        if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % \
                self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train[
            'G_regularizer_clipstep'] if self.opt_train[
                'G_regularizer_clipstep'] else 0
        if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % \
                self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_clip)

        # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0]  # if `reduction='sum'`
        self.log_dict['G_loss'] = G_loss.item()

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        self.netG.eval()
        if self.stage0:
            with torch.no_grad():
                self.Es = self.netG(self.Ls)
        elif self.stage1:
            with torch.no_grad():
                self.E = self.netG(self.L0)
        elif self.stage2:
            with torch.no_grad():
                self.E = self.netG(self.Ls)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self):
        out_dict = OrderedDict()
        if self.stage0:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['Es0'] = self.Es[0].detach()[0].float().cpu()
            out_dict['Hs0'] = self.Hs[0].detach()[0].float().cpu()
        elif self.stage1:
            out_dict['L'] = self.L0.detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  #hide for test

        elif self.stage2:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  #hide for test
        return out_dict

    """
    # ----------------------------------------
    # Information of netG
    # ----------------------------------------
    """

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        msg = self.describe_params(self.netG)
        return msg
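
The clip_grad block in optimize_parameters above caps the global L2 norm of the generator's gradients before the optimizer step. A minimal standalone sketch (hypothetical module and threshold, not from the model above):

import torch
import torch.nn as nn

netG = nn.Linear(8, 8)                       # stand-in for the generator
loss = netG(torch.randn(4, 8)).pow(2).sum()
loss.backward()
# Rescale all gradients so their combined L2 norm is at most max_norm, which
# guards against exploding gradients, as the comment in the model notes.
total_norm = torch.nn.utils.clip_grad_norm_(netG.parameters(),
                                             max_norm=1.0, norm_type=2)
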
Example n. 22
class ESRGAN_EESN_FRCNN_Model(BaseModel):
    def __init__(self, config, device):
        super(ESRGAN_EESN_FRCNN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.config = config
        self.device = device
        #Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG)

        # Discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD)

        #FRCNN_model
        self.netFRCNN = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=True)
        num_classes = 2  # car and background
        in_features = self.netFRCNN.roi_heads.box_predictor.cls_score.in_features
        self.netFRCNN.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes)
        self.netFRCNN.to(self.device)

        self.netG.train()
        self.netD.train()
        self.netFRCNN.train()
        #print(self.configT['pixel_weight'])
        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)
        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        #print(self.configT['feature_weight']+1)
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF)
            self.netF.eval()

        # GD gan loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[
            'D_update_ratio'] else 1
        self.D_init_iters = self.configT['D_init_iters'] if self.configT[
            'D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.configO['weight_decay_G'] if self.configO[
            'weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] if self.configO[
            'weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # FRCNN -- use weight decay
        FRCNN_params = [
            p for p in self.netFRCNN.parameters() if p.requires_grad
        ]
        self.optimizer_FRCNN = torch.optim.SGD(FRCNN_params,
                                               lr=0.005,
                                               momentum=0.9,
                                               weight_decay=0.0005)
        self.optimizers.append(self.optimizer_FRCNN)

        # schedulers
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        self.configS['args']['lr_steps'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights'],
                        gamma=self.configS['args']['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        self.configS['args']['T_period'],
                        eta_min=self.configS['args']['eta_min'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        print(self.configS['args']['restarts'])
        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    '''
    The main repo does not use a collate_fn, reads images with different flags,
    and applies np.ascontiguousarray(); this code may need to change if that
    causes problems.
    '''

    def feed_data(self, image, targets):
        self.var_L = image['image_lq'].to(self.device)
        self.var_H = image['image'].to(self.device)
        input_ref = image['ref'] if 'ref' in image else image['image']
        self.var_ref = input_ref.to(self.device)
        '''
        for t in targets:
            for k, v in t.items():
                print(v)
        '''
        self.targets = [{k: v.to(self.device)
                         for k, v in t.items()} for t in targets]

    def optimize_parameters(self, step):
        #Generator
        for p in self.netG.parameters():
            p.requires_grad = True
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimizer_G.zero_grad()
        self.fake_H, self.final_SR, self.x_learned_lap_fake, _ = self.netG(
            self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  #pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach(
                )  # detach: no gradient should flow back through the target features
                fake_fea = self.netF(
                    self.fake_H)  # note: normalize=False inside netF; worth verifying
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.configT['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.configT['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
            #EESN calculate loss
            self.lap_HR = kornia.laplacian(self.var_H, 3)
            if self.cri_charbonnier:  # charbonnier pixel loss HR and SR
                l_e_charbonnier = 5 * (
                    self.cri_charbonnier(self.final_SR, self.var_H) +
                    self.cri_charbonnier(self.x_learned_lap_fake, self.lap_HR)
                )  # weight of 5 chosen empirically; tune as needed
            l_g_total += l_e_charbonnier

            l_g_total.backward(retain_graph=True)
            # self.optimizer_G.step()

        # Discriminator
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  #to avoid BP to Generator
        if self.configT['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.configT['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real +
                         l_d_fake) / 2  # thinking of adding final sr d loss

        l_d_total.backward()
        self.optimizer_D.step()
        '''
        Freeze EESRGAN
        '''
        #freeze Generator
        '''
        for p in self.netG.parameters():
            p.requires_grad = False
        '''
        for p in self.netD.parameters():
            p.requires_grad = False
        #Run FRCNN
        self.optimizer_FRCNN.zero_grad()
        self.intermediate_img = self.final_SR
        img_count = self.intermediate_img.size()[0]
        self.intermediate_img = [
            self.intermediate_img[i] for i in range(img_count)
        ]
        loss_dict = self.netFRCNN(self.intermediate_img, self.targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        losses.backward()
        self.optimizer_G.step()
        self.optimizer_FRCNN.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        self.log_dict['FRCNN_loss'] = loss_value

    def test(self, valid_data_loader, train=True, testResult=False):
        self.netG.eval()
        self.netFRCNN.eval()
        self.targets = valid_data_loader
        if not testResult:
            with torch.no_grad():
                self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(
                    self.var_L)
                self.x_lap_HR = kornia.laplacian(self.var_H, 3)
        if train:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
        if testResult:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
            evaluate_save(self.netG, self.netFRCNN, self.targets, self.device,
                          self.config)
        self.netG.train()
        self.netFRCNN.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float(
        ).cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        # Discriminator
        s, n = self.get_network_description(self.netD)
        if isinstance(self.netD, nn.DataParallel) or isinstance(
                self.netD, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netD.__class__.__name__,
                self.netD.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netD.__class__.__name__)

        logger.info('Network D structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        if self.cri_fea:  # F, Perceptual Network
            s, n = self.get_network_description(self.netF)
            if isinstance(self.netF, nn.DataParallel) or isinstance(
                    self.netF, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netF.__class__.__name__,
                    self.netF.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netF.__class__.__name__)

            logger.info(
                'Network F structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

        # FRCNN detection network
        s, n = self.get_network_description(self.netFRCNN)
        if isinstance(self.netFRCNN, nn.DataParallel) or isinstance(
                self.netFRCNN, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netFRCNN.__class__.__name__,
                self.netFRCNN.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netFRCNN.__class__.__name__)

        logger.info(
            'Network FRCNN structure: {}, with parameters: {:,d}'.format(
                net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.config['path']['strict_load'])
        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.config['path']['strict_load'])
        load_path_FRCNN = self.config['path']['pretrain_model_FRCNN']
        if load_path_FRCNN:
            logger.info(
                'Loading model for FRCNN [{:s}] ...'.format(load_path_FRCNN))
            self.load_network(load_path_FRCNN, self.netFRCNN,
                              self.config['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netFRCNN, 'FRCNN', iter_step)
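
The 'ragan' branches in the GAN-based models in this collection (see loss_backward and optimize_parameters below) follow the relativistic-average GAN formulation. A minimal, self-contained sketch of that discriminator loss is shown here, assuming a BCE-with-logits base criterion; the actual GANLoss class is defined elsewhere in the repository and may differ.

import torch
import torch.nn.functional as F


def ragan_d_loss(pred_real, pred_fake):
    # Relativistic-average discriminator loss: each logit is compared against
    # the mean logit of the opposite class before the BCE term is applied.
    l_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    l_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (l_real + l_fake) / 2


if __name__ == '__main__':
    # Dummy logits standing in for netD outputs (assumption).
    print(ragan_d_loss(torch.randn(4, 1), torch.randn(4, 1)).item())
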
class IRNpModel(BaseModel):
    def __init__(self, opt):
        super(IRNpModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])

            # feature loss
            if train_opt['feature_weight'] > 0:
                self.Reconstructionf = ReconstructionLoss(
                    losstype=self.train_opt['feature_criterion'])

                self.l_fea_w = train_opt['feature_weight']
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)
            else:
                self.l_fea_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y):
        l_forw_fit = self.train_opt[
            'lambda_fit_forw'] * self.Reconstruction_forw(out[:, :3, :, :], y)

        return l_forw_fit

    def loss_backward(self, x, x_samples):
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt[
            'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)

        # feature loss
        if self.l_fea_w > 0:
            l_back_fea = self.feature_loss(x, x_samples_image)
        else:
            l_back_fea = torch.tensor(0)

        # GAN loss
        pred_g_fake = self.netD(x_samples_image)
        if self.opt['train']['gan_type'] == 'gan':
            l_back_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
        elif self.opt['train']['gan_type'] == 'ragan':
            pred_d_real = self.netD(x).detach()
            l_back_gan = self.l_gan_w * (
                self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2

        return l_back_rec, l_back_fea, l_back_gan

    def feature_loss(self, real, fake):
        real_fea = self.netF(real).detach()
        fake_fea = self.netF(fake)
        l_g_fea = self.l_fea_w * self.Reconstructionf(real_fea, fake_fea)

        return l_g_fea

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.input = self.real_H
        print('input shape: ', self.input.shape)
        self.output = self.netG(x=self.input)
        print('output shape: ', self.output.shape)

        loss = 0
        zshape = self.output[:, 3:, :, :].shape
        print('z shape: ', zshape)

        LR = self.Quantization(self.output[:, :3, :, :])

        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)
        print('y_ shape: ', y_.shape)

        self.fake_H = self.netG(x=y_, rev=True)
        print('fake_H shape: ', self.fake_H.shape)

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            l_forw_fit = self.loss_forward(self.output, self.ref_L)
            l_back_rec, l_back_fea, l_back_gan = self.loss_backward(
                self.real_H, self.fake_H)

            loss += l_forw_fit + l_back_rec + l_back_fea + l_back_gan

            loss.backward()

            # gradient clipping
            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(),
                                         self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.real_H)
        pred_d_fake = self.netD(self.fake_H.detach())
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_back_fea'] = l_back_fea.item()
            self.log_dict['l_back_gan'] = l_back_gan.item()
        self.log_dict['l_d'] = l_d_total.item()

    def test(self):
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H

        print('test mode==>input shape: ', self.input.shape)
        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]
        print('test mode==>zshape: ', zshape)

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            print('test mode==>forw_L shape: ', self.forw_L.shape)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            print('test mode==>y_forw shape: ', y_forw.shape)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
            print('test mode==>fake_H shape: ', self.fake_H.shape)

        self.netG.train()

    def downscale(self, HR_img):
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            LR_img = self.Quantization(LR_img)
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
        self.save_network(self.netD, 'D', iter_label)
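
IRNpModel.upscale above pads the quantized LR image with Gaussian noise so that the channel count matches what the invertible network expects for its reverse pass. The short sketch below reproduces that padding on dummy tensors; DummyInvNet is a purely hypothetical stand-in for the network returned by networks.define_G.

import torch
import torch.nn as nn


class DummyInvNet(nn.Module):
    # Identity-like stand-in for the invertible generator (assumption).
    def forward(self, x, rev=False):
        return x


def pad_with_gaussian(lr_img, scale, gaussian_scale=1.0):
    # Mirrors the zshape computation in IRNpModel.upscale: append
    # C * (scale**2 - 1) noise channels to the LR image before the reverse pass.
    b, c, h, w = lr_img.shape
    z = torch.randn(b, c * (scale ** 2 - 1), h, w, device=lr_img.device)
    return torch.cat((lr_img, gaussian_scale * z), dim=1)


if __name__ == '__main__':
    net = DummyInvNet()
    y_ = pad_with_gaussian(torch.rand(1, 3, 16, 16), scale=2)
    with torch.no_grad():
        hr_img = net(x=y_, rev=True)[:, :3, :, :]
    print(y_.shape, hr_img.shape)  # 12 channels in, a 3-channel image out
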
Esempio n. 24
0
class P_Model(BaseModel):
    def __init__(self, opt):
        super(P_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            #self.init_model()
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            #self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                print('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def init_model(self, scale=0.1):
        # Common practice for initialization.
        for layer in self.netG.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, a=0, mode='fan_in')
                layer.weight.data *= scale  # for residual block
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, a=0, mode='fan_in')
                layer.weight.data *= scale
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias.data, 0.0)

    def feed_data(self, lr_img, ker_map):
        self.var_L = lr_img.to(self.device)  # LQ
        self.real_ker = ker_map.to(self.device)  # real kernel map
        # self.var_L = data['LQ'].to(self.device)
        # self.real_ker = data['real_ker'].to(self.device)

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_ker = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_ker, self.real_ker)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_ker = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['est_ker_map'] = self.fake_ker.detach()[0].float().cpu(
        )  # for validation
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['Batch_est_ker_map'] = self.fake_ker.detach().float().cpu(
        )  # Batch est_ker_map, for train
        out_dict['Batch_LQ'] = self.var_L.detach().float().cpu()
        #out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        #out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
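
test_x8 above implements the geometric self-ensemble from EDSR: the network is run on all eight flip/transpose variants of the input, each output is mapped back through the inverse transform, and the results are averaged. A minimal sketch of the same procedure using torch.flip (a modern equivalent of the numpy round trip used above) follows; the identity model is a hypothetical stand-in for self.netG.

import torch


def _transform(v, op):
    # 'v' flips the width axis, 'h' flips the height axis, 't' transposes H/W.
    if op == 'v':
        return torch.flip(v, dims=[3])
    if op == 'h':
        return torch.flip(v, dims=[2])
    return v.transpose(2, 3)


def forward_x8(lr, model):
    # Geometric self-ensemble: average predictions over the 8 symmetries.
    inputs = [lr]
    for op in ('v', 'h', 't'):
        inputs.extend([_transform(t, op) for t in inputs])
    with torch.no_grad():
        outputs = [model(x) for x in inputs]
    for i in range(len(outputs)):
        if i > 3:
            outputs[i] = _transform(outputs[i], 't')
        if i % 4 > 1:
            outputs[i] = _transform(outputs[i], 'h')
        if (i % 4) % 2 == 1:
            outputs[i] = _transform(outputs[i], 'v')
    return torch.cat(outputs, dim=0).mean(dim=0, keepdim=True)


if __name__ == '__main__':
    identity = lambda x: x  # hypothetical stand-in for self.netG
    print(forward_x8(torch.rand(1, 3, 8, 8), identity).shape)
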
Esempio n. 25
0
class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            self.loss_type = loss_type
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def mixup_data(self, x, y, alpha=1.0, use_cuda=True):
        '''Compute the mixup data. Return the mixed inputs and mixed targets.'''
        batch_size = x.size()[0]
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
        index = torch.randperm(
            batch_size).cuda() if use_cuda else torch.randperm(batch_size)
        mixed_x = lam * x + (1 - lam) * x[index, :]
        mixed_y = lam * y + (1 - lam) * y[index, :]
        return mixed_x, mixed_y

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        '''add mixup operation'''
        #         self.var_L, self.real_H = self.mixup_data(self.var_L, self.real_H)

        self.fake_H = self.netG(self.var_L)
        if self.loss_type == 'fs':
            l_pix = self.l_pix_w * self.cri_pix(
                self.fake_H, self.real_H) + self.l_fs_w * self.cri_fs(
                    self.fake_H, self.real_H)
        elif self.loss_type == 'grad':
            l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
            lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
            l_pix = l1 + lg
        elif self.loss_type == 'grad_fs':
            l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
            lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
            lfs = self.l_fs_w * self.cri_fs(self.fake_H, self.real_H)
            l_pix = l1 + lg + lfs
        else:
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()
        if self.loss_type == 'grad':
            self.log_dict['l_1'] = l1.item()
            self.log_dict['l_grad'] = lg.item()
        if self.loss_type == 'grad_fs':
            self.log_dict['l_1'] = l1.item()
            self.log_dict['l_grad'] = lg.item()
            self.log_dict['l_fs'] = lfs.item()

    def test(self):
        self.netG.eval()

        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

#     def load(self):
#         load_path_G_1 = self.opt['path']['pretrain_model_G_1']
#         load_path_G_2 = self.opt['path']['pretrain_model_G_2']
#         load_path_Gs=[load_path_G_1, load_path_G_2]

#         load_path_G = self.opt['path']['pretrain_model_G']
#         if load_path_G is not None:
#             logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
#             self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
#         if load_path_G_1 is not None:
#             logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_1))
#             logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_2))
#             self.load_network_part(load_path_Gs, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
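
mixup_data above (invoked, when enabled, at the top of optimize_parameters) convexly combines each LR/HR pair with a randomly permuted partner using a Beta(alpha, alpha) coefficient. A minimal CPU sketch of the same step follows; the tensor shapes are illustrative assumptions.

import numpy as np
import torch


def mixup_pair(x, y, alpha=1.0):
    # Draw a mixing coefficient and blend every sample with a shuffled partner.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    return lam * x + (1 - lam) * x[index], lam * y + (1 - lam) * y[index]


if __name__ == '__main__':
    lq = torch.rand(4, 3, 8, 8)    # stand-in LQ batch (assumption)
    gt = torch.rand(4, 3, 32, 32)  # stand-in GT batch (assumption)
    mixed_lq, mixed_gt = mixup_pair(lq, gt)
    print(mixed_lq.shape, mixed_gt.shape)
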
Esempio n. 26
0
class SRDCTModel(BaseModel):
    def __init__(self, opt):
        super(SRDCTModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        # ========================= added by Nan ========================= #
        # Orthogonality Constraint
        self.dct_weight = self.netG.module.get_dct_weight()
        self.dct_weight = self.dct_weight.reshape(64, -1)
        eye = torch.eye(64).to(self.device)
        self.ortho_constraint = 0.5 * F.mse_loss(
            torch.matmul(self.dct_weight, self.dct_weight.T), eye)

        # Complexity Order Constraint
        self.complex_order_constraint = 0.0
        DCT_weight = self.netG.module.get_dct_weight()
        DCT_basis = self.netG.module.get_DCT_2D_Basis().to(self.device)
        for i in range(DCT_weight.shape[0]):
            basis_item = DCT_basis[i]
            weight_item = DCT_weight[i]
            var_loss = self.cri_pix(torch.var(basis_item),
                                    torch.var(weight_item))
            self.complex_order_constraint = self.complex_order_constraint + var_loss

        l_pix = self.l_pix_w * self.cri_pix(
            self.fake_H, self.real_H
        ) + 3.5 * self.ortho_constraint + 0.75 * self.complex_order_constraint
        l_pix.backward(retain_graph=True)
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
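
The orthogonality constraint in SRDCTModel.optimize_parameters above penalizes the deviation of W W^T from the identity, nudging the 64 learned DCT-like filters towards an orthonormal basis. The standalone sketch below reproduces that term; the filter shape used in the demo is an illustrative assumption.

import torch
import torch.nn.functional as F


def orthogonality_penalty(weight):
    # Flatten each filter to a row vector and penalize the mean squared
    # deviation of W @ W^T from the identity, scaled by 0.5 as in the model.
    w = weight.reshape(weight.shape[0], -1)
    eye = torch.eye(w.shape[0], device=w.device)
    return 0.5 * F.mse_loss(torch.matmul(w, w.t()), eye)


if __name__ == '__main__':
    dummy_weight = torch.randn(64, 1, 8, 8)  # stand-in for get_dct_weight() (assumption)
    print(orthogonality_penalty(dummy_weight).item())
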
class GenerativeModel(BaseModel):
    def __init__(self, opt):
        super(GenerativeModel, self).__init__(opt)

        # DISTRIBUTED TRAINING OR NOT
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1

        # DEFINE NETWORKS
        self.netE = networks.define_encoder(opt).to(self.device)
        self.netD = networks.define_decoder(opt).to(self.device)
        self.netF, self.nz, self.stop_gradients = networks.define_flow(opt)
        self.netF.to(self.device)
        if opt['dist']:
            self.netE = DistributedDataParallel(self.netE, device_ids=[torch.cuda.current_device()])
            self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
            self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()])
        else:
            self.netE = DataParallel(self.netE)
            self.netD = DataParallel(self.netD)
            self.netF = DataParallel(self.netF)

        if self.is_train:
            self.netE.train()
            self.netD.train()
            self.netF.train()

        # GET CONFIG PARAMS FOR LOSSES AND LR
        train_opt = opt['train']

        # DEFINE LOSSES, OPTIMIZER AND SCHEDULE
        if self.is_train:
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss(reduction='mean').to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss(reduction='mean').to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']

                if train_opt['add_background_mask']:
                    self.add_mask = True
                else:
                    self.add_mask = False

            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            if train_opt['nll_weight'] is None:
                raise ValueError('nll loss must always be enabled in this version')
            self.cri_nll = NLLLoss(reduction='mean').to(self.device)
            self.l_nll_w = train_opt['nll_weight']

            if train_opt['feature_weight'] > 0:
                self.cri_fea = VGGLoss().to(self.device)
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None

            # optimizers
            if train_opt['lr_E'] > 0:
                self.optimizer_E = torch.optim.Adam(self.netE.parameters(),
                                                    lr=train_opt['lr_E'],
                                                    weight_decay=train_opt['weight_decay_E'] if train_opt[
                                                        'weight_decay_E'] else 0,
                                                    betas=(train_opt['beta1_E'], train_opt['beta2_E']))
                self.optimizers.append(self.optimizer_E)
            else:
                for p in self.netE.parameters():
                    p.requires_grad_(False)
                logger.info('Freeze encoder.')

            if train_opt['lr_D'] > 0:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=train_opt['lr_D'],
                                                    weight_decay=train_opt['weight_decay_D'] if train_opt[
                                                        'weight_decay_D'] else 0,
                                                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_D)
            else:
                for p in self.netD.parameters():
                    p.requires_grad_(False)
                logger.info('Freeze decoder.')

            if train_opt['lr_F'] > 0:
                self.optimizer_F = torch.optim.Adam(self.netF.parameters(),
                                                    lr=train_opt['lr_F'],
                                                    weight_decay=train_opt['weight_decay_F'] if train_opt[
                                                        'weight_decay_F'] else 0,
                                                    betas=(train_opt['beta1_F'], train_opt['beta2_F']))
                self.optimizers.append(self.optimizer_F)
            else:
                for p in self.netF.parameters():
                    p.requires_grad_(False)
                logger.info('Freeze flow.')

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                logger.info('No learning rate scheme is applied.')

            self.log_dict = OrderedDict()

        self.print_network()  # print networks structure
        self.load()  # load G, D, F if needed
        self.test_flow()

    def feed_data(self, data, need_GT=True):
        self.image = data[0].to(self.device)
        if need_GT:
            self.image_gt = self.image

    def optimize_parameters(self, step):
        for optimizer in self.optimizers:
            optimizer.zero_grad()

        z = self.netE(self.image)
        reconstructed = self.netD(z)

        l_total = 0

        if self.cri_pix:  # pixel loss
            if self.add_mask:
                mask = (self.image_gt[:, 0, :, :] == 1).unsqueeze(1).float()
                inv_mask = 1 - mask
                l_pix = (0.2 * self.cri_pix(reconstructed * mask, self.image_gt * mask) +
                         0.8 * self.cri_pix(reconstructed * inv_mask, self.image_gt * inv_mask))
            else:
                l_pix = self.l_pix_w * self.cri_pix(reconstructed, self.image_gt)
            l_total += l_pix

        if self.cri_fea:  # feature loss
            l_fea = self.l_fea_w * self.cri_fea(reconstructed, self.image_gt)
            l_total += l_fea

        # negative likelihood loss
        if self.stop_gradients:
            noise_out, logdets = self.netF(z.detach())
        else:
            noise_out, logdets = self.netF(z)

        l_nll = self.l_nll_w * self.cri_nll(noise_out, logdets)
        l_total += l_nll

        l_total.backward()
        for optimizer in self.optimizers:
            optimizer.step()

        # set log
        if self.cri_pix:
            self.log_dict['l_pix'] = l_pix.item()
        if self.cri_fea:
            self.log_dict['l_fea'] = l_fea.item()
        if self.cri_nll:
            self.log_dict['l_nll'] = l_nll.item()

    def sample_images(self, n=25):
        self.netF.eval()
        self.netD.eval()
        with torch.no_grad():
            noise = torch.randn(n, self.nz).to(self.device)
            if isinstance(self.netF, nn.DataParallel) or isinstance(self.netF, DistributedDataParallel):
                sample = self.netD(self.netF.module.reverse(noise)).detach().float().cpu()
            else:
                sample = self.netD(self.netF.reverse(noise)).detach().float().cpu()
        self.netF.train()
        self.netD.train()
        return sample

    def get_current_log(self):
        return self.log_dict

    def print_network(self):
        for name, net in [('E', self.netE), ('D', self.netD), ('F', self.netF)]:
            s, n = self.get_network_description(net)
            if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                                 net.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(net.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
                logger.info(s)

        if self.is_train and self.cri_fea:
            vgg_net = self.cri_fea.vgg
            s, n = self.get_network_description(vgg_net)
            if isinstance(vgg_net, nn.DataParallel) or isinstance(vgg_net, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(vgg_net.__class__.__name__,
                                                 vgg_net.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(vgg_net.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network VGG structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

    def load(self):
        load_path_E = self.opt['path']['pretrained_encoder']
        if load_path_E is not None:
            logger.info('Loading model for E [{:s}] ...'.format(load_path_E))
            self.load_network(load_path_E, self.netE, self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrained_decoder']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])

        load_path_F = self.opt['path']['pretrained_flow']
        if load_path_F is not None:
            logger.info('Loading model for F [{:s}] ...'.format(load_path_F))
            self.load_network(load_path_F, self.netF, self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netE, 'E', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netF, 'F', iter_step)

    def test_flow(self):
        with torch.no_grad():
            test_input = torch.randn((2, self.nz)).to(self.device)
            test_output, _ = self.netF(test_input)
            if isinstance(self.netF, nn.DataParallel) or isinstance(self.netF, DistributedDataParallel):
                test_input2 = self.netF.module.reverse(test_output)
            else:
                test_input2 = self.netF.reverse(test_output)
            assert torch.allclose(test_input, test_input2), 'Flow model is incorrect'
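
GenerativeModel.sample_images above draws Gaussian noise, maps it back through the flow's reverse pass, and decodes the resulting latent codes into images. The sketch below mirrors that path with tiny hypothetical stand-ins for the flow and decoder; the real networks come from define_flow and define_decoder elsewhere in the repository.

import torch
import torch.nn as nn


class IdentityFlow(nn.Module):
    # Trivial stand-in for the normalizing flow (assumption).
    def forward(self, z):
        return z, torch.zeros(z.shape[0])  # (noise, logdets)

    def reverse(self, noise):
        return noise


class ToyDecoder(nn.Module):
    # Maps a latent vector to a small 3-channel image (assumption).
    def __init__(self, nz):
        super().__init__()
        self.fc = nn.Linear(nz, 3 * 8 * 8)

    def forward(self, z):
        return self.fc(z).view(-1, 3, 8, 8)


def sample_from_flow(netF, netD, nz, n=4):
    # Noise -> flow reverse -> decoder, as in GenerativeModel.sample_images.
    with torch.no_grad():
        noise = torch.randn(n, nz)
        return netD(netF.reverse(noise)).cpu()


if __name__ == '__main__':
    nz = 16
    print(sample_from_flow(IdentityFlow(), ToyDecoder(nz), nz).shape)  # [4, 3, 8, 8]
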
Esempio n. 28
0
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # G Rank-content loss
            if train_opt['R_weight'] > 0:
                self.l_R_w = train_opt['R_weight']  # load rank-content loss
                self.R_bias = train_opt['R_bias']
                self.netR = networks.define_R(opt).to(self.device)
                if opt['dist']:
                    self.netR = DistributedDataParallel(
                        self.netR, device_ids=[torch.cuda.current_device()])
                else:
                    self.netR = DataParallel(self.netR)
            else:
                logger.info('Remove rank-content loss.')
                self.l_R_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme [{:s}] is not recognized.'.format(
                        train_opt['lr_scheme']))

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
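            # l_g_gan: either a standard BCE-style GAN loss pushing D(fake_H) toward the
            # real label, or the relativistic average (RaGAN) form that compares D's fake
            # scores against its (detached) real scores.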
            if self.opt['train']['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if self.l_R_w > 0:  # rank-content loss
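                # netR (the Ranker) predicts a perceptual rank score for the SR output;
                # sigmoid(score - R_bias) maps it to a bounded penalty, which is summed
                # over the batch and scaled by l_R_w.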
                l_g_rank = self.netR(self.fake_H)
                l_g_rank = torch.sigmoid(l_g_rank - self.R_bias)
                l_g_rank = torch.sum(l_g_rank)
                l_g_rank = self.l_R_w * l_g_rank
                l_g_total += l_g_rank

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            if self.l_R_w > 0:
                self.log_dict['l_g_rank'] = l_g_rank.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

            if self.l_R_w:  # R, Ranker Network
                s, n = self.get_network_description(self.netR)
                if isinstance(self.netR, nn.DataParallel) or isinstance(
                        self.netR, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netR.__class__.__name__,
                        self.netR.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netR.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network Ranker structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info('Loading model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
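The generator and discriminator objectives above go through the project's GANLoss helper. As a point of reference, here is a minimal, self-contained sketch of the relativistic average GAN ('ragan') terms used in optimize_parameters, written with plain BCEWithLogitsLoss; assuming GANLoss('ragan', 1.0, 0.0) reduces to BCE-with-logits against all-ones / all-zeros targets, which may differ from the project's exact implementation.

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def ragan_g_loss(pred_fake, pred_real, weight=1.0):
    # Generator term: fake scores should rise above the average real score,
    # real scores (detached) should fall below the average fake score.
    pred_real = pred_real.detach()
    loss = (bce(pred_real - pred_fake.mean(), torch.zeros_like(pred_real)) +
            bce(pred_fake - pred_real.mean(), torch.ones_like(pred_fake))) / 2
    return weight * loss

def ragan_d_loss(pred_real, pred_fake):
    # Discriminator term: real should score above the average fake, and vice versa.
    loss_real = bce(pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    loss_fake = bce(pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (loss_real + loss_fake) / 2

# Toy usage with random discriminator logits
pred_real = torch.randn(4, 1)
pred_fake = torch.randn(4, 1)
print(ragan_g_loss(pred_fake, pred_real).item(), ragan_d_loss(pred_real, pred_fake).item())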
Example 29
class ModelPlain4(ModelBase):
    """Train with pixel loss"""
    def __init__(self, opt):
        super(ModelPlain4, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.netG = define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

    """
    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------
    """

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        self.opt_train = self.opt['train']    # training option
        self.load()                           # load model
        self.netG.train()                     # set training mode, e.g. for BatchNorm
        self.define_loss()                    # define loss
        self.define_optimizer()               # define optimizer
        self.define_scheduler()               # define scheduler
        self.log_dict = OrderedDict()         # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):
        load_path_G = self.opt['path']['pretrained_netG']
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
                                                        self.opt_train['G_scheduler_milestones'],
                                                        self.opt_train['G_scheduler_gamma']
                                                        ))
    """
    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------
    """

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        self.L = data['L'].to(self.device)  # low-quality image
        self.k = data['k'].to(self.device)  # blur kernel
        self.sf = int(data['sf'][0, ...].squeeze().cpu().numpy())  # scale factor (np.int is removed in recent NumPy)
        self.sigma = data['sigma'].to(self.device)  # noise level
        if need_H:
            self.H = data['H'].to(self.device)  # H

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        self.G_optimizer.zero_grad()
        self.E = self.netG(self.L, self.k, self.sf, self.sigma)  # same inputs as in test()
        G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0
        if G_optimizer_clipgrad > 0:
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=G_optimizer_clipgrad, norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0
        if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0
        if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_clip)

        self.log_dict['G_loss'] = G_loss.item()  #/self.E.size()[0]

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.E = self.netG(self.L, self.k, self.sf, self.sigma)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self, need_H=True):
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach()[0].float().cpu()
        out_dict['E'] = self.E.detach()[0].float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach()[0].float().cpu()
        return out_dict

    # ----------------------------------------
    # get L, E, H batch images
    # ----------------------------------------
    def current_results(self, need_H=True):
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach().float().cpu()
        out_dict['E'] = self.E.detach().float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach().float().cpu()
        return out_dict

    """
    # ----------------------------------------
    # Information of netG
    # ----------------------------------------
    """

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        msg = self.describe_params(self.netG)
        return msg
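ModelPlain4 exposes a small training API (init_train, feed_data, optimize_parameters, current_log, test, save). The loop below is a hypothetical sketch of how such a model might be driven; opt, train_loader, and the logging/saving intervals are placeholders assumed for illustration, not code from this listing.

# Hypothetical driver loop for ModelPlain4 (opt and train_loader are placeholders).
model = ModelPlain4(opt)      # opt: parsed options dict with 'train' and 'path' sections
model.init_train()            # load pretrained G, set train mode, build loss/optimizer/scheduler

current_step = 0
for epoch in range(100):
    for train_data in train_loader:          # each batch: {'L', 'k', 'sf', 'sigma', 'H'}
        current_step += 1
        model.feed_data(train_data)          # move the batch to the model's device
        model.optimize_parameters(current_step)
        if current_step % 200 == 0:
            print(current_step, model.current_log())   # e.g. {'G_loss': ...}
        if current_step % 5000 == 0:
            model.save(current_step)                   # checkpoint G
# (the base class typically also provides a scheduler-step helper, omitted here)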
Example 30
class ESRGAN_EESN_Model(BaseModel):
    def __init__(self, config, device):
        super(ESRGAN_EESN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.device = device
        #Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG, device_ids=[1, 0])

        # discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD, device_ids=[1, 0])

        self.netG.train()
        self.netD.train()
        #print(self.configT['pixel_weight'])
        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)
        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        #print(self.configT['feature_weight']+1)
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF, device_ids=[1, 0])
            self.netF.eval()

        # GD gan loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[
            'D_update_ratio'] else 1
        self.D_init_iters = self.configT['D_init_iters'] if self.configT[
            'D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.configO['weight_decay_G'] if self.configO[
            'weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] if self.configO[
            'weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        self.configS['args']['lr_steps'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights'],
                        gamma=self.configS['args']['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        self.configS['args']['T_period'],
                        eta_min=self.configS['args']['eta_min'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights']))
        else:
            raise NotImplementedError(
                'Learning rate scheme [{:s}] is not recognized.'.format(
                    self.configS['type']))
        print(self.configS['args']['restarts'])
        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    '''
    Note: the main repo does not use collate_fn, reads images with different flags,
    and also uses np.ascontiguousarray(); this code may need to change if problems arise.
    '''

    def feed_data(self, data):
        self.var_L = data['image_lq'].to(self.device)
        self.var_H = data['image'].to(self.device)
        input_ref = data['ref'] if 'ref' in data else data['image']
        self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        #Generator
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimizer_G.zero_grad()
        self.fake_H, self.final_SR, _, _ = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  #pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()  # no gradients flow back through netF for the real features
                fake_fea = self.netF(self.fake_H)  # TODO: confirm netF input normalization settings
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.configT['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.configT['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
            #EESN calculate loss
            if self.cri_charbonnier:  # charbonnier pixel loss HR and SR
                l_e_charbonnier = 5 * self.cri_charbonnier(
                    self.final_SR, self.var_H)  # weight of 5 chosen empirically
            l_g_total += l_e_charbonnier

            l_g_total.backward()
            self.optimizer_G.step()

        # discriminator
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  #to avoid BP to Generator
        if self.configT['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.configT['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real +
                         l_d_fake) / 2  # TODO: consider also adding a D loss on final_SR

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(
                self.var_L)
            _, _, _, self.x_lap_HR = self.netG(self.var_H)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float(
        ).cpu()
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        # Discriminator
        s, n = self.get_network_description(self.netD)
        if isinstance(self.netD, nn.DataParallel) or isinstance(
                self.netD, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netD.__class__.__name__,
                self.netD.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netD.__class__.__name__)

        logger.info('Network D structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        if self.cri_fea:  # F, Perceptual Network
            s, n = self.get_network_description(self.netF)
            if isinstance(self.netF, nn.DataParallel) or isinstance(
                    self.netF, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netF.__class__.__name__,
                    self.netF.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netF.__class__.__name__)

            logger.info(
                'Network F structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.config['path']['strict_load'])
        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.config['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
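ESRGAN_EESN_Model supervises the EESN branch with CharbonnierLoss, which this listing does not define. Below is a minimal sketch under the common assumption that the Charbonnier loss is the smooth L1 variant sqrt(diff^2 + eps); the exact epsilon and reduction used by the project's class may differ.

import torch
import torch.nn as nn

class CharbonnierLoss(nn.Module):
    """Smooth L1-like penalty: sqrt(diff^2 + eps) (assumed formulation)."""
    def __init__(self, eps=1e-9):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        # mean reduction here; some implementations sum over all elements instead
        return torch.mean(torch.sqrt(diff * diff + self.eps))

# Toy check on random tensors
x, y = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
print(CharbonnierLoss()(x, y).item())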