Example #1
    def forward(self, in0, in1, retPerLayer=False):
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        # unit-normalize each layer's activations along the channel axis,
        # then keep the squared per-channel difference for every layer
        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        # lpips=True applies the learned linear (1x1 conv) weights per channel;
        # spatial=True returns a full-resolution map instead of a spatial average
        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        # sum the per-layer distances into a single value
        val = res[0]
        for l in range(1,self.L):
            val += res[l]
        
        if(retPerLayer):
            return (val, res)
        else:
            return val
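This forward is normally reached through the public wrapper rather than called directly. A minimal usage sketch with the pip-installable lpips package (the backbone choice and input sizes here are illustrative):

    import torch
    import lpips  # pip install lpips

    loss_fn = lpips.LPIPS(net='alex')        # builds the backbone plus the learned linear layers
    img0 = torch.rand(1, 3, 64, 64) * 2 - 1  # inputs are expected in [-1, +1]
    img1 = torch.rand(1, 3, 64, 64) * 2 - 1
    d = loss_fn(img0, img1)                  # distance tensor of shape [1, 1, 1, 1]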
Example #2
    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        val = res[0]
        for l in range(1,self.L):
            val += res[l]
        
        if(retPerLayer):
            return (val, res)
        else:
            return val
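Relative to Example #1, this variant adds a normalize flag for inputs in [0, 1] and passes out_HW (rather than out_H) so non-square inputs upsample correctly. A sketch of both options, again via the lpips package:

    import torch
    import lpips

    loss_fn = lpips.LPIPS(net='vgg')
    img0, img1 = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)  # in [0, 1]
    d = loss_fn(img0, img1, normalize=True)  # rescaled to [-1, +1] internally
    val, per_layer = loss_fn(img0, img1, retPerLayer=True, normalize=True)  # total plus per-layer terms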
Example #3
    def forward(self, in0, in1, layers=None):
        layers = layers or self.layers
        outs0, outs1 = self.net(in0), self.net(in1)
        feats0, feats1, diffs = {}, {}, {}

        # outs are namedtuples; walk the named feature maps, skipping layers not requested
        for kk, key in enumerate(outs0._asdict().keys()):
            if layers is not None and key not in layers:
                continue
            feats0[kk] = lpips.normalize_tensor(outs0[kk])
            feats1[kk] = lpips.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk])**2

        # per-layer term: channel-summed squared difference, spatially averaged
        res = torch.stack([
            spatial_average(diffs[kk].sum(dim=1, keepdim=True)).sum()
            for kk in diffs.keys()
        ])
        return res.mean()
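For reference, the helpers these snippets lean on are small. The sketches below stay close to what the lpips package ships; the signatures are inferred from the call sites above:

    import torch
    import torch.nn as nn

    def normalize_tensor(in_feat, eps=1e-10):
        # unit-normalize every feature vector along the channel dimension
        norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
        return in_feat / (norm_factor + eps)

    def spatial_average(in_tens, keepdim=True):
        # mean over the spatial dimensions H and W
        return in_tens.mean([2, 3], keepdim=keepdim)

    def upsample(in_tens, out_HW=(64, 64)):
        # bilinearly resize a difference map back to the input resolution
        return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)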
Example #4
    def forward(self, in0, in1, retPerLayer=False):
        # pdb.set_trace()
        # (Pdb) in0.size() -- torch.Size([1, 3, 256, 256])
        # (Pdb) in1.size() -- torch.Size([1, 3, 256, 256])
        # (Pdb) in0.mean().item(), in0.min().item(), in0.max().item()
        # (0.3927803933620453, 0.007843137718737125, 1.0)
        
        # v0.0 - original release had a bug, where input was not scaled

        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        val = res[0]
        for l in range(1,self.L):
            val += res[l]
        
        if(retPerLayer):
            return (val, res)
        else:
            return val

    def forward(self, inp):
        if self.normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            inp = 2 * inp - 1

        inp_input = self.ps.scaling_layer(inp) if self.ps.version == '0.1' else inp
        outs = self.ps.net.forward(inp_input)
        feats = []
        for kk in range(self.ps.L):
            h, w = outs[kk].shape[2], outs[kk].shape[3]
            assert h == w
            out_res = self.layer_res[self.ps.pnet_type][kk] if self.pooling else h
            if out_res > 0:
                # Reduce spatial size
                feats_kk = outs[kk] if h == out_res else F.adaptive_avg_pool2d(outs[kk], out_res)
                # Normalize and apply the learned weights
                feats_kk = self.ps.lins[kk].model[-1].weight.data.pow(0.5) * lpips.normalize_tensor(feats_kk)
                # Flatten all spatial dimensions and divide by resolution due to HW normalization in eq. 1
                feats.append(feats_kk.reshape(feats_kk.shape[0], -1) / out_res)
        # Return the concatenated feature
        return torch.cat(feats, dim=1)
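Hypothetical usage of the extractor above. The class name and its construction are assumptions (self.ps is taken to be a wrapped perceptual-similarity model); since the learned weights are folded in as their square roots, the squared L2 distance between two returned feature vectors reproduces the LPIPS distance when pooling leaves the spatial size unchanged:

    # hypothetical: LPIPSFeatureExtractor is a stand-in name for the class defining
    # the forward above, wrapping a lpips.LPIPS instance as self.ps
    extractor = LPIPSFeatureExtractor(pooling=True, normalize=True)
    f0, f1 = extractor(img0), extractor(img1)  # [N, D] concatenated features
    d = ((f0 - f1) ** 2).sum(dim=1)            # LPIPS distance, up to pooling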