def forward(self, in0, in1, retPerLayer=False):
    """Compute the perceptual distance between two input batches.

    Both inputs are run through ``self.net``. For each of the ``self.L``
    layers, the two activations are normalized with ``util.normalize_tensor``
    and their squared difference is reduced to one result per layer:

    * ``self.lpips`` True: weighted by ``self.lins[kk].model``.
    * ``self.lpips`` False: summed over dim 1 (keepdim).
    * ``self.spatial`` True: resized with ``upsample`` to ``in0.shape[2:]``.
    * ``self.spatial`` False: averaged with ``spatial_average`` (keepdim).

    The per-layer results are then summed.

    Args:
        in0, in1: input batches. Presumably (N, C, H, W) tensors -- TODO
            confirm; ``in0.shape[2:]`` is used as the spatial output size.
        retPerLayer: if True, also return the list of per-layer results.

    Returns:
        ``val``, the sum of the per-layer results. When ``retPerLayer`` is
        True, returns ``(val, res)``, where ``res`` is the list of per-layer
        results. Those entries are left unmodified by the summation.
    """
    # v0.0 - original release had a bug, where input was not scaled.
    # Only version '0.1' applies the scaling layer.
    if self.version == '0.1':
        in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
    else:
        in0_input, in1_input = in0, in1
    outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

    res = []
    for kk in range(self.L):
        feat0 = util.normalize_tensor(outs0[kk])
        feat1 = util.normalize_tensor(outs1[kk])
        diff = (feat0 - feat1) ** 2

        # Learned per-layer weighting (LPIPS) vs. a plain channel sum.
        if self.lpips:
            layer_map = self.lins[kk].model(diff)
        else:
            layer_map = diff.sum(dim=1, keepdim=True)

        # Spatial mode keeps a map sized like the input; otherwise reduce spatially.
        if self.spatial:
            res.append(upsample(layer_map, out_HW=in0.shape[2:]))
        else:
            res.append(spatial_average(layer_map, keepdim=True))

    # Accumulate out of place: an in-place `+=` starting from res[0] would
    # overwrite the first per-layer entry that is returned when
    # retPerLayer is True.
    val = res[0]
    for layer_res in res[1:]:
        val = val + layer_res

    if retPerLayer:
        return (val, res)
    return val
def forward(self, in0, in1):
    """Compute the learned perceptual distance between two input batches.

    Both inputs are run through ``self.net``. Each layer's activations are
    normalized with ``util.normalize_tensor``, and the squared difference is
    passed through the matching linear model: ``self.lin0``..``self.lin4``,
    plus ``self.lin5``/``self.lin6`` when ``self.pnet_type == 'squeeze'``.

    Args:
        in0, in1: input batches. Presumably (N, C, H, W) tensors -- TODO
            confirm against callers.

    Returns:
        If ``self.spatial`` is True: a list with one unreduced linear-model
        output per network layer.
        Otherwise: the sum over layers of each output's mean over dim 3 and
        then dim 2, reshaped to ``(size(0), size(1), 1, 1)``.
    """
    if self.version == '0.0':
        # v0.0 - original release had a bug, where input was not scaled
        in0_input = in0
        in1_input = in1
    else:
        # v0.1 - scale only on this path, where the result is actually used
        in0_input = self.scaling_layer(in0)
        in1_input = self.scaling_layer(in1)

    outs0 = self.net.forward(in0_input)
    outs1 = self.net.forward(in1_input)

    # Squared difference of the normalized activations, one entry per layer.
    diffs = [
        (util.normalize_tensor(out0) - util.normalize_tensor(outs1[kk])) ** 2
        for kk, out0 in enumerate(outs0)
    ]

    lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
    if self.pnet_type == 'squeeze':
        lin_models.extend([self.lin5, self.lin6])

    if self.spatial:
        # One unreduced map per network layer.
        return [lin_models[kk].model(diff) for kk, diff in enumerate(diffs)]

    # Spatially average each layer's output (dim 3, then dim 2) and sum the
    # layers in order, left to right, as the original unrolled chain did.
    layer_vals = [
        torch.mean(torch.mean(lin.model(diffs[kk]), dim=3), dim=2)
        for kk, lin in enumerate(lin_models)
    ]
    val = layer_vals[0]
    for layer_val in layer_vals[1:]:
        val = val + layer_val

    val = val.view(val.size()[0], val.size()[1], 1, 1)
    return val