def forward(self, in0, in1, retPerLayer=False):
    """Compute the perceptual distance between ``in0`` and ``in1``.

    Both inputs go through ``self.net`` (after ``self.scaling_layer`` when
    ``self.version == '0.1'``). For each of the ``self.L`` layer outputs the
    features are normalized with ``util.normalize_tensor`` and their squared
    difference is taken. If ``self.lpips`` is set, each difference map is passed
    through ``self.lins[kk].model``; otherwise it is summed over dim 1. With
    ``self.spatial`` the per-layer maps are upsampled to height ``in0.shape[2]``;
    otherwise they are spatially averaged. The per-layer results are then summed.

    Args:
        in0, in1: input tensors; presumably (N, C, H, W) image batches since
            ``shape[2]`` is used as the output height -- TODO confirm.
        retPerLayer: if True, also return the list of per-layer results.

    Returns:
        The summed distance ``val``, or ``(val, res)`` when ``retPerLayer`` is
        True, where ``res`` holds one entry per layer.
    """
    # v0.0 - original release had a bug, where input was not scaled
    in0_input, in1_input = (
        (self.scaling_layer(in0), self.scaling_layer(in1))
        if self.version == '0.1' else (in0, in1)
    )
    outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

    feats0, feats1, diffs = {}, {}, {}
    for kk in range(self.L):
        feats0[kk] = util.normalize_tensor(outs0[kk])
        feats1[kk] = util.normalize_tensor(outs1[kk])
        diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

    if self.lpips:
        if self.spatial:
            res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2])
                   for kk in range(self.L)]
        else:
            res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True)
                   for kk in range(self.L)]
    elif self.spatial:
        res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2])
               for kk in range(self.L)]
    else:
        res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
               for kk in range(self.L)]

    # Accumulate out of place: ``val += x`` on a tensor is an in-place add that
    # would overwrite res[0] with the running total, corrupting the per-layer
    # list returned when retPerLayer is True.
    val = res[0]
    for layer_val in res[1:]:
        val = val + layer_val

    if retPerLayer:
        return (val, res)
    return val
def forward(self, in0, in1, retPerLayer=False, normalize=False):
    """Compute the perceptual distance between ``in0`` and ``in1``.

    Optionally rescales the inputs from [0, 1] to [-1, +1], then passes both
    through ``self.net`` (after ``self.scaling_layer`` when
    ``self.version == '0.1'``). For each of the ``self.L`` layer outputs the
    features are normalized with ``lpips.normalize_tensor`` and their squared
    difference is taken. If ``self.lpips`` is set, each difference map is passed
    through ``self.lins[kk].model``; otherwise it is summed over dim 1. With
    ``self.spatial`` the per-layer maps are upsampled to ``in0.shape[2:]``;
    otherwise they are spatially averaged. The per-layer results are then summed.

    Args:
        in0, in1: input tensors; presumably (N, C, H, W) image batches since
            ``shape[2:]`` is used as the output size -- TODO confirm.
        retPerLayer: if True, also return the list of per-layer results.
        normalize: set when inputs are in [0, 1]; they are mapped to [-1, +1]
            with ``2 * x - 1`` before anything else.

    Returns:
        The summed distance ``val``, or ``(val, res)`` when ``retPerLayer`` is
        True, where ``res`` holds one entry per layer.
    """
    if normalize:
        # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
        in0 = 2 * in0 - 1
        in1 = 2 * in1 - 1

    # v0.0 - original release had a bug, where input was not scaled
    in0_input, in1_input = (
        (self.scaling_layer(in0), self.scaling_layer(in1))
        if self.version == '0.1' else (in0, in1)
    )
    outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

    feats0, feats1, diffs = {}, {}, {}
    for kk in range(self.L):
        feats0[kk] = lpips.normalize_tensor(outs0[kk])
        feats1[kk] = lpips.normalize_tensor(outs1[kk])
        diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

    if self.lpips:
        if self.spatial:
            res = [upsample(self.lins[kk].model(diffs[kk]), out_HW=in0.shape[2:])
                   for kk in range(self.L)]
        else:
            res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True)
                   for kk in range(self.L)]
    elif self.spatial:
        res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:])
               for kk in range(self.L)]
    else:
        res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
               for kk in range(self.L)]

    # Accumulate out of place: ``val += x`` on a tensor is an in-place add that
    # would overwrite res[0] with the running total, corrupting the per-layer
    # list returned when retPerLayer is True.
    val = res[0]
    for layer_val in res[1:]:
        val = val + layer_val

    if retPerLayer:
        return (val, res)
    return val
def forward(self, in0, in1, layers=None):
    """Return the mean, over the selected layers, of a scalar feature distance.

    Both inputs go through ``self.net``; its output is read via ``_asdict()``
    (for layer names) and integer indexing (for layer tensors), e.g. a
    namedtuple. A layer is used when its name is in ``layers``. If ``layers`` is
    falsy, ``self.layers`` is used instead, and if that is None every layer is
    used. For each used layer, both outputs are normalized with
    ``lpips.normalize_tensor``. The squared difference is summed over dim 1 and
    passed through ``spatial_average``, then summed to a 0-d tensor. The
    per-layer values are stacked and averaged.

    NOTE: if no layer is selected, ``torch.stack`` receives an empty list and
    raises.
    """
    selected = layers or self.layers
    feats_a, feats_b = self.net(in0), self.net(in1)

    per_layer = []
    for idx, name in enumerate(feats_a._asdict().keys()):
        if selected is not None and name not in selected:
            continue
        norm_a = lpips.normalize_tensor(feats_a[idx])
        norm_b = lpips.normalize_tensor(feats_b[idx])
        sq_diff = (norm_a - norm_b) ** 2
        per_layer.append(spatial_average(sq_diff.sum(dim=1, keepdim=True)).sum())

    return torch.stack(per_layer).mean()
def forward(self, in0, in1, retPerLayer=False):
    """Compute the perceptual distance between ``in0`` and ``in1``.

    Both inputs go through ``self.net`` (after ``self.scaling_layer`` when
    ``self.version == '0.1'``). For each of the ``self.L`` layer outputs the
    features are normalized with ``util.normalize_tensor`` and their squared
    difference is taken. If ``self.lpips`` is set, each difference map is passed
    through ``self.lins[kk].model``; otherwise it is summed over dim 1. With
    ``self.spatial`` the per-layer maps are upsampled to height ``in0.shape[2]``;
    otherwise they are spatially averaged. The per-layer results are then summed.

    Args:
        in0, in1: input tensors; presumably (N, C, H, W) image batches since
            ``shape[2]`` is used as the output height -- TODO confirm.
        retPerLayer: if True, also return the list of per-layer results.

    Returns:
        The summed distance ``val``, or ``(val, res)`` when ``retPerLayer`` is
        True, where ``res`` holds one entry per layer.
    """
    # NOTE(review): a past debugging session recorded in0/in1 of size
    # (1, 3, 256, 256) with values in roughly [0.008, 1.0]. The scaling path may
    # expect inputs in [-1, 1] -- confirm the range callers actually pass.

    # v0.0 - original release had a bug, where input was not scaled
    in0_input, in1_input = (
        (self.scaling_layer(in0), self.scaling_layer(in1))
        if self.version == '0.1' else (in0, in1)
    )
    outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

    feats0, feats1, diffs = {}, {}, {}
    for kk in range(self.L):
        feats0[kk] = util.normalize_tensor(outs0[kk])
        feats1[kk] = util.normalize_tensor(outs1[kk])
        diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

    if self.lpips:
        if self.spatial:
            res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2])
                   for kk in range(self.L)]
        else:
            res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True)
                   for kk in range(self.L)]
    elif self.spatial:
        res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2])
               for kk in range(self.L)]
    else:
        res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
               for kk in range(self.L)]

    # Accumulate out of place: ``val += x`` on a tensor is an in-place add that
    # would overwrite res[0] with the running total, corrupting the per-layer
    # list returned when retPerLayer is True.
    val = res[0]
    for layer_val in res[1:]:
        val = val + layer_val

    if retPerLayer:
        return (val, res)
    return val
def forward(self, inp):
    """Embed ``inp`` as a flat vector built from the wrapped metric's layer features.

    ``inp`` is mapped from [0, 1] to [-1, +1] when ``self.normalize`` is set. It
    is passed through ``self.ps.scaling_layer`` when ``self.ps.version == '0.1'``
    and then through ``self.ps.net``. For each of the ``self.ps.L`` layer outputs:

    * the output is adaptive-average-pooled to
      ``self.layer_res[self.ps.pnet_type][kk]`` when ``self.pooling`` is set
      (otherwise its own height is kept);
    * it is normalized with ``lpips.normalize_tensor``;
    * it is scaled by the square root of ``self.ps.lins[kk].model[-1].weight``;
    * it is flattened per sample and divided by the output resolution.

    Layers whose target resolution is <= 0 are skipped.

    Returns:
        The per-layer flattened features concatenated along dim 1, i.e. a
        tensor of shape (batch, total_features).

    Raises:
        ValueError: if a layer output's spatial dims (shape[2], shape[3]) differ.
    """
    if self.normalize:
        # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
        inp = 2 * inp - 1
    inp_input = self.ps.scaling_layer(inp) if self.ps.version == '0.1' else inp
    outs = self.ps.net.forward(inp_input)

    feats = []
    for kk in range(self.ps.L):
        h, w = outs[kk].shape[2], outs[kk].shape[3]
        # Explicit check instead of ``assert`` (stripped under ``python -O``): the
        # per-resolution normalization below assumes square feature maps.
        if h != w:
            raise ValueError(
                f"expected square feature maps, got {h}x{w} at layer {kk}")
        out_res = self.layer_res[self.ps.pnet_type][kk] if self.pooling else h
        if out_res > 0:
            # Reduce spatial size
            feats_kk = outs[kk] if h == out_res else F.adaptive_avg_pool2d(
                outs[kk], out_res)
            # Normalize and apply the learned weights
            feats_kk = self.ps.lins[kk].model[-1].weight.data.pow(
                0.5) * lpips.normalize_tensor(feats_kk)
            # Flatten all spatial dimensions and divide by resolution due to HW
            # normalization in eq. 1
            feats.append(feats_kk.reshape(feats_kk.shape[0], -1) / out_res)
    # Return the concatenated feature
    return torch.cat(feats, dim=1)