def STSIM(self, img1, img2, sub_sample=True):
        """Structural texture similarity between two grayscale image batches.

        Both inputs are decomposed with the same ``SCFpyr_PyTorch`` pyramid.
        ``self.pooling`` scores every pair of corresponding subbands, and those
        per-subband scores are averaged.

        :param img1: tensor of shape [N, 1, H, W] (enforced by the asserts).
        :param img2: tensor with exactly the same shape as ``img1``.
        :param sub_sample: forwarded to ``SCFpyr_PyTorch``.
        :return: mean over subbands (dim 0) of the stacked ``self.pooling``
            outputs; the resulting shape depends on what ``pooling`` returns.
        """
        assert img1.shape == img2.shape
        assert len(img1.shape) == 4  # [N,C,H,W]
        assert img1.shape[1] == 1  # gray image

        pyramid = SCFpyr_PyTorch(sub_sample=sub_sample, device=self.device)
        subbands_a = pyramid.getlist(pyramid.build(img1))
        subbands_b = pyramid.getlist(pyramid.build(img2))

        # One similarity score per pair of corresponding subbands.
        scores = [self.pooling(band_a, band_b)
                  for band_a, band_b in zip(subbands_a, subbands_b)]
        return torch.stack(scores).mean(dim=0)

    def STSIM2(self, img1, img2):
        """STSIM-2 score: per-subband terms plus cross-band correlation terms.

        The per-subband terms are ``self.pooling`` over a sub-sampled
        ``SCFpyr_PyTorch`` decomposition. The cross terms are computed on a
        non-sub-sampled decomposition, presumably so that bands at different
        scales share a spatial size (TODO confirm what ``compute_cross_term``
        requires). Cross terms are taken between:

        * the same orientation at adjacent scales ``(k-1, k)`` for
          ``k = 2 .. len(pyramid) - 2``;
        * every pair of orientations within each scale ``1 .. len(pyramid) - 2``.

        Each cross term is ``compute_cross_term`` on band magnitudes
        (``self.abs``), averaged over dims 1-3. All terms are stacked and
        averaged along dim 0.

        :param img1: image batch; must have the same shape as ``img2``.
        :param img2: image batch.
        :return: mean of all stacked terms along dim 0.
        """
        assert img1.shape == img2.shape

        pyr_sub = SCFpyr_PyTorch(sub_sample=True, device=self.device)
        pyr_full = SCFpyr_PyTorch(sub_sample=False, device=self.device)

        # Per-subband terms from the sub-sampled pyramid.
        subbands1 = pyr_sub.getlist(pyr_sub.build(img1))
        subbands2 = pyr_sub.getlist(pyr_sub.build(img2))
        terms = [self.pooling(b1, b2) for b1, b2 in zip(subbands1, subbands2)]

        # Cross terms come from the non-sub-sampled pyramid.
        levels1 = pyr_full.build(img1)
        levels2 = pyr_full.build(img2)
        n_orient = len(levels1[1])

        def cross(a_p, a_q, b_p, b_q):
            # a_p/a_q: two bands of image 1; b_p/b_q: the matching bands of
            # image 2. Magnitudes are fed to compute_cross_term and the result
            # is reduced over dims 1-3.
            return self.compute_cross_term(
                self.abs(a_p), self.abs(a_q), self.abs(b_p), self.abs(b_q)
            ).mean(dim=[1, 2, 3])

        # Same orientation, adjacent scales (k-1, k).
        for prev1, cur1, prev2, cur2 in zip(levels1[1:-2], levels1[2:-1],
                                            levels2[1:-2], levels2[2:-1]):
            for o in range(n_orient):
                terms.append(cross(prev1[o], cur1[o], prev2[o], cur2[o]))

        # Same scale, every unordered pair of orientations.
        for level1, level2 in zip(levels1[1:-1], levels2[1:-1]):
            for o1, o2 in itertools.combinations(range(n_orient), 2):
                terms.append(cross(level1[o1], level1[o2],
                                   level2[o1], level2[o2]))

        return torch.mean(torch.stack(terms), dim=0)

    def STSIM_M(self, imgs):
        """Extract STSIM-M texture features from a batch of images.

        ``imgs`` is decomposed with a sub-sampled ``SCFpyr_PyTorch`` pyramid.
        Features are computed on band magnitudes (``self.abs``) and appended
        in this order:

        1. For every band returned by ``getlist``: mean, variance, and the
           lag-1 autocorrelation coefficients along dim 2 and along dim 3.
        2. For every scale ``coeffs[1:-1]``: the correlation coefficient
           between each unordered pair of orientations.
        3. For every orientation and each adjacent scale pair: the correlation
           coefficient between the two bands. The earlier-index band is first
           resized to the later one's spatial size with ``F.interpolate``
           (default nearest mode).

        All correlation-type features centre the maps (subtract the per-image
        mean) before multiplying. They then divide by the variance, or by the
        product of standard deviations, so they are true correlation
        coefficients rather than raw second moments.

        :param imgs: [N,C=1,H,W]
        :return: ``torch.stack`` of the features along dim 0, i.e. shape
            [num_features, N] when each band is a 4-D [N, C, H, W] tensor
            (assumed from the reductions over dims 1-3).
        """
        s = SCFpyr_PyTorch(sub_sample=True, device=self.device)
        coeffs = s.build(imgs)
        dims = [1, 2, 3]

        def centered(x):
            # Remove the per-image mean so products below are covariances.
            return x - torch.mean(x, dim=dims, keepdim=True)

        def corr(x, y):
            # Correlation coefficient between two equally sized maps.
            x = centered(x)
            y = centered(y)
            return (torch.mean(x * y, dim=dims) /
                    torch.sqrt(torch.var(x, dim=dims)) /
                    torch.sqrt(torch.var(y, dim=dims)))

        f = []
        # Single-subband statistics.
        for c in s.getlist(coeffs):
            c = self.abs(c)
            var = torch.var(c, dim=dims)
            f.append(torch.mean(c, dim=dims))
            f.append(var)
            # Lag-1 autocorrelation coefficients: centre first, then
            # normalise the neighbour covariance by the variance.
            c = centered(c)
            f.append(
                torch.mean(c[:, :, :-1, :] * c[:, :, 1:, :], dim=dims) / var)
            f.append(
                torch.mean(c[:, :, :, :-1] * c[:, :, :, 1:], dim=dims) / var)

        # Correlation statistics across orientations (same scale).
        for orients in coeffs[1:-1]:
            for c1, c2 in itertools.combinations(orients, 2):
                f.append(corr(self.abs(c1), self.abs(c2)))

        # Correlation statistics across adjacent scales (same orientation).
        for orient in range(len(coeffs[1])):
            for height in range(1, len(coeffs) - 2):
                c1 = self.abs(coeffs[height][orient])
                c2 = self.abs(coeffs[height + 1][orient])
                # Bring both maps to the same spatial size before correlating.
                c1 = F.interpolate(c1, size=c2.shape[2:])
                f.append(corr(c1, c2))
        return torch.stack(f)