Beispiel #1
0
    def init_training_settings(self):
        """Prepare discriminator, EMA generator, losses and optimizers for training.

        Reads ``self.opt['train']`` and ``self.opt['path']`` and:

        * builds ``self.net_d``, moves it with ``model_to_device``, prints it and,
          if ``path.pretrain_model_d`` is set, loads those weights;
        * builds ``self.net_g_ema`` on ``self.device`` and loads the
          ``'params_ema'`` entry of ``path.pretrain_model_g`` if that is set,
          otherwise calls ``self.model_ema(0)`` to copy the ``net_g`` weights;
        * puts ``net_g`` / ``net_d`` in train mode and ``net_g_ema`` in eval mode;
        * instantiates the GAN loss class named by ``train.gan_opt['type']``
          from ``loss_module`` and stores the regularization settings;
        * calls ``setup_optimizers`` and ``setup_schedulers``.

        The ``train.gan_opt`` mapping in ``self.opt`` is not modified.
        """
        train_opt = self.opt['train']

        # define network net_d
        self.net_d = networks.define_net_d(deepcopy(self.opt['network_d']))
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        # load pretrained model
        load_path = self.opt['path'].get('pretrain_model_d', None)
        if load_path is not None:
            self.load_network(self.net_d, load_path,
                              self.opt['path']['strict_load'])

        # define network net_g with Exponential Moving Average (EMA)
        # net_g_ema only used for testing on one GPU and saving, do not need to
        # wrap with DistributedDataParallel
        self.net_g_ema = networks.define_net_g(deepcopy(
            self.opt['network_g'])).to(self.device)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_model_g', None)
        if load_path is not None:
            self.load_network(self.net_g_ema, load_path,
                              self.opt['path']['strict_load'], 'params_ema')
        else:
            self.model_ema(0)  # copy net_g weight

        self.net_g.train()
        self.net_d.train()
        self.net_g_ema.eval()

        # define losses
        # gan loss (originally noted as wgan; the class comes from gan_opt['type'])
        # Work on a copy before popping 'type'. Popping from self.opt in place
        # removed the key, so a second call raised KeyError and self.opt no
        # longer recorded which loss type was used.
        gan_opt = dict(train_opt['gan_opt'])
        cri_gan_cls = getattr(loss_module, gan_opt.pop('type'))
        self.cri_gan = cri_gan_cls(**gan_opt).to(self.device)
        # regularization weights
        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
        self.path_reg_weight = train_opt['path_reg_weight']  # for generator

        # presumably how often (in iterations) each regularizer is applied;
        # usage is outside this method
        self.net_g_reg_every = train_opt['net_g_reg_every']
        self.net_d_reg_every = train_opt['net_d_reg_every']
        self.mixing_prob = train_opt['mixing_prob']

        # running value for path-length regularization; presumably updated in
        # the training step (not visible here)
        self.mean_path_length = 0

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()
Beispiel #2
0
    def __init__(self, opt):
        """Create the SR generator and, when training, its training state.

        Args:
            opt (dict): Model options. ``opt['network_g']`` configures the
                generator. ``path.pretrain_model_g`` optionally points at a
                checkpoint, and ``path.strict_load`` is passed to
                ``load_network``.
        """
        super(SRModel, self).__init__(opt)

        # Build the generator, place it on the device(s), and log its layout.
        generator = networks.define_net_g(deepcopy(opt['network_g']))
        self.net_g = self.model_to_device(generator)
        self.print_network(self.net_g)

        # Restore generator weights only when a checkpoint path is configured.
        path_opt = self.opt['path']
        checkpoint = path_opt.get('pretrain_model_g')
        if checkpoint is not None:
            self.load_network(self.net_g, checkpoint, path_opt['strict_load'])

        if self.is_train:
            self.init_training_settings()
Beispiel #3
0
    def __init__(self, opt):
        """Build the generator, then load pretrained and ArcFace weights into it.

        Args:
            opt (dict): Model options. ``opt['network_g']`` configures the
                generator. From ``self.opt['path']`` this reads
                ``pretrain_model_g`` (optional checkpoint), ``strict_load``,
                and the optional ``pretrain_model_arcface``. The last one
                defaults to the previously hard-coded ArcFace checkpoint path,
                so existing configs behave as before.
        """
        super(SRModel, self).__init__(opt)

        # define network
        self.net_g = networks.define_net_g(deepcopy(opt['network_g']))
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)
        # per-model lists; presumably filled during training/validation
        # (not visible here)
        self.offset_frame = []
        self.offset_mask = []
        self.avg_feat_gt = []
        self.avg_feat_out = []

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_model_g', None)
        if load_path is not None:
            self.load_network(self.net_g, load_path,
                              self.opt['path']['strict_load'])

        # Load ArcFace weights into the matching entries of net_g.
        # The checkpoint location used to be hard-coded; it is now
        # configurable, with the old location as the fallback.
        # NOTE(review): this runs after pretrain_model_g is loaded, so any
        # ArcFace weights stored in that checkpoint are overwritten here.
        # Confirm that this is intended.
        arcface_path = self.opt['path'].get(
            'pretrain_model_arcface',
            '/home/wei/exp/EDVR/arcfacemodel/model_ir_se50.pth')
        pretrain_dict = torch.load(arcface_path, map_location='cpu')

        state_dict = self.net_g.state_dict()
        net_keys = list(state_dict.keys())
        for name, value in pretrain_dict.items():
            # Each pretrained tensor goes to the FIRST net_g key that contains
            # its name as a substring. Presumably this lets names match keys
            # that carry extra prefixes.
            # NOTE(review): ambiguous names could match the wrong key; confirm.
            for key in net_keys:
                if name in key:
                    state_dict[key] = value
                    break
        self.net_g.load_state_dict(state_dict)

        if self.is_train:
            self.init_training_settings()
Beispiel #4
0
    def __init__(self, opt):
        """Create the U-Net generator, read its loss-output transform, set up training.

        Args:
            opt (dict): Model options. ``opt['network_g']`` configures the
                generator. ``self.opt['output_transform_for_loss']`` is stored
                on the model and printed. ``path.pretrain_model_g`` and
                ``path.strict_load`` control optional checkpoint loading.
        """
        super(UNetModel, self).__init__(opt)

        # Generator is built first; it is moved to the device(s) only after
        # the extra option below has been read and reported.
        generator = networks.define_net_g(deepcopy(opt['network_g']))

        # Extra option, presumably applied to the output before the loss
        # (name-based; its use is not visible here).
        transform = self.opt['output_transform_for_loss']
        self.output_transform_for_loss = transform
        print('Output transform (WT_HF): {}'.format(transform))

        self.net_g = self.model_to_device(generator)
        self.print_network(self.net_g)

        # Restore generator weights only when a checkpoint path is configured.
        path_opt = self.opt['path']
        checkpoint = path_opt.get('pretrain_model_g')
        if checkpoint is not None:
            self.load_network(self.net_g, checkpoint, path_opt['strict_load'])

        if self.is_train:
            self.init_training_settings()
Beispiel #5
0
    def __init__(self, opt):
        """Create the StyleGAN2 generator, a fixed latent batch, and training state.

        Args:
            opt (dict): Model options. ``opt['network_g']`` configures the
                generator and must provide ``num_style_feat``.
                ``path.pretrain_model_g`` optionally points at a checkpoint.
                ``path.param_key_g`` (default ``'params'``) selects the entry
                to load, and ``path.strict_load`` is passed to
                ``load_network``.
        """
        super(StyleGAN2Model, self).__init__(opt)

        # Build the generator, place it on the device(s), and log its layout.
        generator = networks.define_net_g(deepcopy(opt['network_g']))
        self.net_g = self.model_to_device(generator)
        self.print_network(self.net_g)

        # Restore generator weights only when a checkpoint path is configured.
        path_opt = self.opt['path']
        checkpoint = path_opt.get('pretrain_model_g')
        if checkpoint is not None:
            key_in_file = path_opt.get('param_key_g', 'params')
            self.load_network(
                self.net_g,
                checkpoint,
                path_opt['strict_load'],
                key_in_file,
            )

        # Latent size, and a fixed batch of 16 latent vectors drawn once on
        # self.device (presumably reused to produce comparable samples).
        self.num_style_feat = opt['network_g']['num_style_feat']
        self.fixed_sample = torch.randn(
            16, self.num_style_feat, device=self.device)

        if self.is_train:
            self.init_training_settings()