예제 #1
0
    def define_loss(self):
        """Build the pixel (G), perceptual (F) and adversarial (D) loss criteria.

        All settings are read from ``self.opt_train`` (plus ``self.opt['dist']``)
        and the results are stored as attributes on ``self``:

        * ``G_lossfn`` / ``G_lossfn_weight`` -- pixel loss chosen by
          ``opt_train['G_lossfn_type']`` ('l1', 'l2', 'l2sum' or 'ssim').  If
          ``opt_train['G_lossfn_weight']`` is not positive, ``G_lossfn`` is set
          to ``None`` and ``G_lossfn_weight`` is left unassigned.
        * ``F_lossfn`` / ``F_lossfn_weight`` -- a ``PerceptualLoss``; same
          disabling rule, driven by ``opt_train['F_lossfn_weight']``.
        * ``D_lossfn`` / ``D_lossfn_weight`` -- a ``GANLoss`` for
          ``opt_train['gan_type']``; always built.
        * ``D_update_ratio`` / ``D_init_iters`` -- the configured values, or 1
          and 0 respectively when the configured value is falsy.

        Raises:
            NotImplementedError: if ``opt_train['G_lossfn_type']`` is not a
                supported pixel-loss type.
        """
        # ------------------------------------
        # G_loss
        # ------------------------------------
        if self.opt_train['G_lossfn_weight'] > 0:
            G_lossfn_type = self.opt_train['G_lossfn_type']
            if G_lossfn_type == 'l1':
                self.G_lossfn = nn.L1Loss().to(self.device)
            elif G_lossfn_type == 'l2':
                self.G_lossfn = nn.MSELoss().to(self.device)
            elif G_lossfn_type == 'l2sum':
                self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
            elif G_lossfn_type == 'ssim':
                self.G_lossfn = SSIMLoss().to(self.device)
            else:
                # Plain '{}' instead of '{:s}': the 's' format spec raises
                # TypeError for non-str values (e.g. None for a missing key),
                # which would hide this NotImplementedError.
                raise NotImplementedError(
                    'Loss type [{}] is not found.'.format(G_lossfn_type))
            self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
        else:
            print('Do not use pixel loss.')
            self.G_lossfn = None

        # ------------------------------------
        # F_loss
        # ------------------------------------
        if self.opt_train['F_lossfn_weight'] > 0:
            F_lossfn_type = self.opt_train['F_lossfn_type']
            F_use_input_norm = self.opt_train['F_use_input_norm']
            F_feature_layer = self.opt_train['F_feature_layer']
            self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer,
                                           use_input_norm=F_use_input_norm,
                                           lossfn_type=F_lossfn_type)
            if self.opt['dist']:
                # Distributed run: the whole loss module is simply moved to
                # this process's device.
                self.F_lossfn = self.F_lossfn.to(self.device)
            else:
                # Non-distributed run: the VGG sub-module goes through
                # model_to_device (defined elsewhere -- presumably wraps it
                # for multi-GPU; verify) and the criterion is moved separately.
                self.F_lossfn.vgg = self.model_to_device(self.F_lossfn.vgg)
                self.F_lossfn.lossfn = self.F_lossfn.lossfn.to(self.device)
            self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
        else:
            print('Do not use feature loss.')
            self.F_lossfn = None

        # ------------------------------------
        # D_loss
        # ------------------------------------
        # NOTE(review): 1.0 / 0.0 are presumably the real/fake target values
        # expected by GANLoss -- confirm against its signature.
        self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0,
                                0.0).to(self.device)
        self.D_lossfn_weight = self.opt_train['D_lossfn_weight']

        # Falsy configured values (e.g. 0 or None) fall back to the defaults.
        self.D_update_ratio = self.opt_train['D_update_ratio'] or 1
        self.D_init_iters = self.opt_train['D_init_iters'] or 0
예제 #2
0
파일: model_plain4.py 프로젝트: znsc/KAIR
 def define_loss(self):
     """Build the pixel loss criterion from ``self.opt_train``.

     Sets ``self.G_lossfn`` to the criterion chosen by
     ``opt_train['G_lossfn_type']`` ('l1', 'l2', 'l2sum' or 'ssim'), moved to
     ``self.device``, and ``self.G_lossfn_weight`` to
     ``opt_train['G_lossfn_weight']``.

     Raises:
         NotImplementedError: if the loss type is not supported.
     """
     G_lossfn_type = self.opt_train['G_lossfn_type']
     if G_lossfn_type == 'l1':
         self.G_lossfn = nn.L1Loss().to(self.device)
     elif G_lossfn_type == 'l2':
         self.G_lossfn = nn.MSELoss().to(self.device)
     elif G_lossfn_type == 'l2sum':
         self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
     elif G_lossfn_type == 'ssim':
         self.G_lossfn = SSIMLoss().to(self.device)
     else:
         # Plain '{}' instead of '{:s}': the 's' format spec raises TypeError
         # for non-str values (e.g. None), which would hide this error.
         raise NotImplementedError('Loss type [{}] is not found.'.format(G_lossfn_type))
     self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
예제 #3
0
    def define_loss(self):
        """Build the pixel (G), feature (F) and adversarial (D) loss criteria.

        All settings come from ``self.opt_train``; results are stored on
        ``self``:

        * ``G_lossfn`` / ``G_lossfn_weight`` -- pixel loss chosen by
          ``opt_train['G_lossfn_type']`` ('l1', 'l2', 'l2sum' or 'ssim').  If
          ``opt_train['G_lossfn_weight']`` is not positive, ``G_lossfn`` is
          ``None`` and ``G_lossfn_weight`` is left unassigned.
        * ``F_lossfn`` / ``F_lossfn_weight`` -- feature-comparison loss chosen by
          ``opt_train['F_lossfn_type']`` ('l1' or 'l2'); same disabling rule,
          driven by ``opt_train['F_lossfn_weight']``.
        * ``D_lossfn`` / ``D_lossfn_weight`` -- a ``GANLoss`` for
          ``opt_train['gan_type']``; always built.
        * ``D_update_ratio`` / ``D_init_iters`` -- the configured values, or 1
          and 0 respectively when the configured value is falsy.

        Raises:
            NotImplementedError: if the G or F loss type is not supported.
        """
        # ------------------------------------
        # G_loss
        # ------------------------------------
        if self.opt_train['G_lossfn_weight'] > 0:
            G_lossfn_type = self.opt_train['G_lossfn_type']
            if G_lossfn_type == 'l1':
                self.G_lossfn = nn.L1Loss().to(self.device)
            elif G_lossfn_type == 'l2':
                self.G_lossfn = nn.MSELoss().to(self.device)
            elif G_lossfn_type == 'l2sum':
                self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
            elif G_lossfn_type == 'ssim':
                self.G_lossfn = SSIMLoss().to(self.device)
            else:
                # Plain '{}' instead of '{:s}': the 's' format spec raises
                # TypeError for non-str values (e.g. None), hiding this error.
                raise NotImplementedError(
                    'Loss type [{}] is not found.'.format(G_lossfn_type))
            self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
        else:
            print('Do not use pixel loss.')
            self.G_lossfn = None

        # ------------------------------------
        # F_loss
        # ------------------------------------
        # Only the comparison criterion is built here; no feature extractor is
        # constructed in this method (presumably it lives elsewhere -- verify).
        if self.opt_train['F_lossfn_weight'] > 0:
            F_lossfn_type = self.opt_train['F_lossfn_type']
            if F_lossfn_type == 'l1':
                self.F_lossfn = nn.L1Loss().to(self.device)
            elif F_lossfn_type == 'l2':
                self.F_lossfn = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(F_lossfn_type))
            self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
        else:
            print('Do not use feature loss.')
            self.F_lossfn = None

        # ------------------------------------
        # D_loss
        # ------------------------------------
        # NOTE(review): 1.0 / 0.0 are presumably the real/fake target values
        # expected by GANLoss -- confirm against its signature.
        self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0,
                                0.0).to(self.device)
        self.D_lossfn_weight = self.opt_train['D_lossfn_weight']

        # Falsy configured values (e.g. 0 or None) fall back to the defaults.
        self.D_update_ratio = self.opt_train['D_update_ratio'] or 1
        self.D_init_iters = self.opt_train['D_init_iters'] or 0