Example #1
0
    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 norm_img=False,
                 criterion='l1'):
        """Set up the VGG feature extractor and the feature-distance criterion.

        Args:
            layer_weights (dict): Its keys are the VGG layer names passed to
                ``VGGFeatureExtractor`` as ``layer_name_list``. The dict itself
                is stored on ``self.layer_weights`` (values presumably per-layer
                loss weights; confirm in ``forward``).
            vgg_type (str): VGG variant forwarded to ``VGGFeatureExtractor``.
                Default: 'vgg19'.
            use_input_norm (bool): Forwarded to ``VGGFeatureExtractor``.
                Default: True.
            perceptual_weight (float): Stored on ``self.perceptual_weight``.
                Default: 1.0.
            style_weight (float): Stored on ``self.style_weight``. Default: 0.
            norm_img (bool): Stored on ``self.norm_img``. Default: False.
            criterion (str): Name of the feature-distance criterion. Accepted
                values:
                - 'l1' uses ``torch.nn.L1Loss``.
                - 'l2' uses ``torch.nn.MSELoss``.
                - 'fro' sets ``self.criterion`` to None.
                Default: 'l1'.

        Raises:
            NotImplementedError: If ``criterion`` is not 'l1', 'l2' or 'fro'.
        """
        super(PerceptualLoss, self).__init__()
        self.norm_img = norm_img
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights

        # Resolve the criterion before constructing the VGG extractor so an
        # unsupported name fails immediately. The module is only assigned to
        # ``self`` afterwards, keeping the original registration order.
        if criterion == 'l1':
            criterion_module = torch.nn.L1Loss()
        elif criterion == 'l2':
            # torch.nn has no ``L2loss``; the squared-L2 criterion is MSELoss.
            criterion_module = torch.nn.MSELoss()
        elif criterion == 'fro':
            # No loss module. Presumably ``forward`` computes a Frobenius
            # norm directly when criterion_type == 'fro'; confirm there.
            criterion_module = None
        else:
            raise NotImplementedError(
                f'{criterion} criterion has not been supported.')

        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm)

        self.criterion_type = criterion
        self.criterion = criterion_module
Example #2
0
    def __init__(self, num_feat, dict_path):
        """Build the component dictionary, VGG extractor, attention blocks and decoder.

        Args:
            num_feat (int): Base channel count. The decoder's channel widths
                are num_feat * 8, * 8, * 4, * 2 and * 1.
            dict_path (str): File read with ``torch.load``. The loaded object
                is stored on ``self.dict``.
        """
        super().__init__()
        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
        # part_sizes: [80, 80, 50, 110]
        # attn_channels[i] is the channel count of the attention blocks used
        # at spatial size feature_sizes[i].
        attn_channels = [128, 256, 512, 512]
        self.feature_sizes = np.array([256, 128, 64, 32])
        self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
        # Starts False. Presumably set once self.dict is moved to the input's
        # device; confirm in forward.
        self.flag_dict_device = False

        # NOTE(review): torch.load unpickles the file. Only load trusted
        # dictionary files.
        self.dict = torch.load(dict_path)

        # VGG19 extractor for the layers listed above, built with
        # requires_grad=False.
        self.vgg_extractor = VGGFeatureExtractor(
            layer_name_list=self.vgg_layers,
            vgg_type='vgg19',
            use_input_norm=True,
            range_norm=True,
            requires_grad=False)

        # One attention block per (size, facial part) pair, keyed
        # '<part>_<size>'. Each block fuses dictionary features with input
        # features.
        blocks = nn.ModuleDict()
        for size, channels in zip(self.feature_sizes, attn_channels):
            for part in self.parts:
                blocks[f'{part}_{size}'] = AttentionBlock(channels)
        self.attn_blocks = blocks

        # Channel widths shared by the dilation block and the decoder.
        c8, c4, c2 = num_feat * 8, num_feat * 4, num_feat * 2

        # Multi-scale dilation block.
        self.multi_scale_dilation = MSDilationBlock(c8, dilation=[4, 3, 2, 1])

        # Upsampling path: four SFT up-blocks that halve the width stepwise.
        self.upsample0 = SFTUpBlock(c8, c8)
        self.upsample1 = SFTUpBlock(c8, c4)
        self.upsample2 = SFTUpBlock(c4, c2)
        self.upsample3 = SFTUpBlock(c2, num_feat)

        # Reconstruction head: refine at num_feat channels, project to
        # 3 channels, then squash with Tanh.
        head_layers = [
            SpectralNorm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            UpResBlock(num_feat),
            UpResBlock(num_feat),
            nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.upsample4 = nn.Sequential(*head_layers)