def STSIM(self, img1, img2, sub_sample=True):
    '''
    Compare two gray-image batches with the STSIM metric.

    Both batches are decomposed with the same ``SCFpyr_PyTorch`` steerable
    pyramid. ``self.pooling`` scores each pair of corresponding bands, and
    those per-band scores are averaged.

    :param img1: first batch, 4-D [N, C=1, H, W]
    :param img2: second batch, same shape as ``img1``
    :param sub_sample: forwarded to ``SCFpyr_PyTorch``
    :return: the per-band ``self.pooling`` scores stacked along a new dim 0
        and averaged over that dim
    '''
    shape = img1.shape
    assert shape == img2.shape
    assert len(shape) == 4  # layout [N, C, H, W]
    assert shape[1] == 1  # single-channel (gray) input only

    pyramid = SCFpyr_PyTorch(sub_sample=sub_sample, device=self.device)
    bands1 = pyramid.getlist(pyramid.build(img1))
    bands2 = pyramid.getlist(pyramid.build(img2))

    per_band = [self.pooling(b1, b2) for b1, b2 in zip(bands1, bands2)]
    return torch.stack(per_band).mean(dim=0)
def STSIM2(self, img1, img2):
    '''
    Compute the STSIM-2 score between two image batches.

    The result averages two groups of terms:

    * per-band ``self.pooling`` scores on a sub-sampled steerable pyramid
      (the same terms ``STSIM`` uses), and
    * ``self.compute_cross_term`` scores on a pyramid built without
      sub-sampling, taken over
      (a) adjacent oriented scales at the same orientation, and
      (b) every pair of orientations within the same oriented scale.

    The non-sub-sampled pyramid is presumably used so that bands at
    adjacent scales share a spatial size without resampling (TODO confirm
    against ``compute_cross_term``).

    :param img1: first image batch; must have the same shape as ``img2``
        (presumably [N, 1, H, W] as in ``STSIM``; not checked here)
    :param img2: second image batch
    :return: mean of all band and cross terms over the stacked dim 0. Each
        cross term is reduced over dims 1-3, so this is one value per batch
        element, assuming ``self.pooling`` also returns that shape.
    '''
    assert img1.shape == img2.shape
    s = SCFpyr_PyTorch(sub_sample=True, device=self.device)
    s_nosub = SCFpyr_PyTorch(sub_sample=False, device=self.device)

    # Single-band terms on the sub-sampled pyramid.
    pyrA = s.getlist(s.build(img1))
    pyrB = s.getlist(s.build(img2))
    stsimg2 = list(map(self.pooling, pyrA, pyrB))

    # Cross terms on the non-sub-sampled pyramid. Only indices
    # 1 .. len-2 of the built pyramid (the oriented scales) are used.
    bandsAn = s_nosub.build(img1)
    bandsBn = s_nosub.build(img2)
    Nor = len(bandsAn[1])
    n_levels = len(bandsAn)

    # Every band magnitude feeds several cross terms, so compute each one
    # once instead of calling self.abs again for every pair.
    # magA[k][o] == self.abs(bandsAn[k + 1][o]), likewise for magB.
    magA = [[self.abs(bandsAn[scale][orient]) for orient in range(Nor)]
            for scale in range(1, n_levels - 1)]
    magB = [[self.abs(bandsBn[scale][orient]) for orient in range(Nor)]
            for scale in range(1, n_levels - 1)]

    # Across scale, same orientation: each oriented scale with the previous one.
    for k in range(1, len(magA)):
        for orient in range(Nor):
            stsimg2.append(
                self.compute_cross_term(
                    magA[k - 1][orient], magA[k][orient],
                    magB[k - 1][orient], magB[k][orient],
                ).mean(dim=[1, 2, 3]))

    # Across orientation, same scale: every unordered pair of orientations.
    for k in range(len(magA)):
        for orient in range(Nor - 1):
            for orient2 in range(orient + 1, Nor):
                stsimg2.append(
                    self.compute_cross_term(
                        magA[k][orient], magA[k][orient2],
                        magB[k][orient], magB[k][orient2],
                    ).mean(dim=[1, 2, 3]))

    return torch.mean(torch.stack(stsimg2), dim=0)
def STSIM_M(self, imgs):
    '''
    Extract the STSIM-M feature vector of every image in a batch.

    All statistics are computed on magnitudes (``self.abs``) of the
    coefficients of a sub-sampled steerable pyramid. They are appended in
    this order:

    1. For every band returned by ``s.getlist(coeffs)``, four features:
       the mean, the variance, the mean product of vertically adjacent
       pixels divided by the variance, and the mean product of
       horizontally adjacent pixels divided by the variance.
    2. For every oriented scale (``coeffs[1:-1]``) and every pair of its
       orientations, the mean product of the two magnitudes.
    3. For every orientation and every pair of adjacent oriented scales,
       the mean product of the two magnitudes divided by both standard
       deviations. The first band is first resized to the second band's
       spatial size with ``F.interpolate``.

    NOTE(review): the neighbour and cross-scale products are not
    mean-centred before normalising. The published STSIM-M uses correlation
    coefficients of centred values, so confirm this is intended. A band with
    zero variance yields inf/nan because of the divisions.

    :param imgs: [N,C=1,H,W]
    :return: tensor of shape [num_features, N]. Each feature is reduced
        over dims 1-3, and the features are stacked along dim 0.
    '''
    s = SCFpyr_PyTorch(sub_sample=True, device=self.device)
    coeffs = s.build(imgs)
    f = []

    # single subband statistics
    for c in s.getlist(coeffs):
        c = self.abs(c)
        var = torch.var(c, dim=[1, 2, 3])
        f.append(torch.mean(c, dim=[1, 2, 3]))
        f.append(var)
        f.append(
            torch.mean(c[:, :, :-1, :] * c[:, :, 1:, :], dim=[1, 2, 3]) / var)
        f.append(
            torch.mean(c[:, :, :, :-1] * c[:, :, :, 1:], dim=[1, 2, 3]) / var)

    # Magnitudes of the oriented scales, computed once and shared by both
    # correlation loops below. mags[k] corresponds to coeffs[k + 1].
    mags = [[self.abs(band) for band in orients] for orients in coeffs[1:-1]]

    # correlation statistics
    # across orientations, same scale
    for scale_mags in mags:
        for c1, c2 in itertools.combinations(scale_mags, 2):
            f.append(torch.mean(c1 * c2, dim=[1, 2, 3]))

    # across adjacent scales, same orientation
    for orient in range(len(coeffs[1])):
        for height in range(len(mags) - 1):
            c1 = mags[height][orient]
            c2 = mags[height + 1][orient]
            # the bands differ in size (sub-sampled pyramid); match c1 to c2
            c1 = F.interpolate(c1, size=c2.shape[2:])
            f.append(
                torch.mean(c1 * c2, dim=[1, 2, 3])
                / torch.sqrt(torch.var(c1, dim=[1, 2, 3]))
                / torch.sqrt(torch.var(c2, dim=[1, 2, 3])))
    return torch.stack(f)