예제 #1
0
파일: mdsi.py 프로젝트: iampakos/piqa
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the MDSI between ``input`` and ``target`` and reduce it.

        Both images are checked with ``_assert_type``. They are then
        downsampled by an integer factor ``M`` derived from their smaller
        spatial side, converted with ``self.convert`` and passed to
        :func:`mdsi`.

        Args:
            input: Input image batch. Checked as a 4-D tensor with 3 channels
                and values in ``[0, self.value_range]``, on the device of
                ``self.kernel``.
            target: Target image batch, checked like ``input``.

        Returns:
            The MDSI values reduced according to ``self.reduction``.
        """

        _assert_type(
            [input, target],
            device=self.kernel.device,
            dim_range=(4, 4),
            n_channels=3,
            value_range=(0., self.value_range),
        )

        # Downsample by M = round(min(h, w) / 256). Round half *up*, like the
        # reference MATLAB implementation (MATLAB's round goes half away from
        # zero). Python's built-in round() uses banker's rounding and would
        # choose M = 2 instead of 3 for a 640-pixel side, for instance.
        # The division by 256 is exact for integer sizes, so + 0.5 is safe.
        _, _, h, w = input.size()
        M = int(min(h, w) / 256 + 0.5)

        if M > 1:
            input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
            target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)

        # RGB to LHM
        input = self.convert(input)
        target = self.convert(target)

        # MDSI
        l = mdsi(input, target, kernel=self.kernel, **self.kwargs)

        return _reduce(l, self.reduction)
예제 #2
0
파일: ssim.py 프로젝트: iampakos/piqa
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        r"""Validate the image pair, compute MS-SSIM and reduce the result.

        Args:
            input: Input image batch. Checked with ``_assert_type`` as 4-D,
                with ``self.kernel.size(0)`` channels, on the device of
                ``self.kernel`` and with values in ``[0, self.value_range]``.
            target: Target image batch with the same requirements.

        Returns:
            The MS-SSIM values reduced with ``self.reduction``.
        """

        kernel = self.kernel

        _assert_type(
            [input, target],
            device=kernel.device,
            dim_range=(4, 4),
            n_channels=kernel.size(0),
            value_range=(0., self.value_range),
        )

        similarity = ms_ssim(
            input,
            target,
            kernel=kernel,
            weights=self.weights,
            **self.kwargs,
        )
        return _reduce(similarity, self.reduction)
예제 #3
0
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the LPIPS distance between ``input`` and ``target``.

        Both images are validated and, when ``self.scaling`` is set, shifted
        by ``self.shift`` and divided by ``self.scale``. They are then fed
        through ``self.net``. For every pair of feature maps, the feature
        vectors are unit-normalized along dim 1, their squared difference is
        averaged over the last two dims, and the result is weighted by the
        matching layer in ``self.lins``. The per-layer values are summed.

        Args:
            input: Input image batch. Checked as 4-D with 3 channels, on the
                device of ``self.shift``. When ``self.scaling`` is set, values
                must lie in ``[0, 1]``.
            target: Target image batch with the same requirements.

        Returns:
            The per-image distances reduced with ``self.reduction``.
        """

        # Guards the normalization below against all-zero feature vectors
        # (e.g. a position where every ReLU activation is zero), which would
        # otherwise give 0 / 0 = NaN. Mirrors the epsilon used by the
        # reference LPIPS `normalize_tensor`.
        eps = 1e-10

        _assert_type(
            [input, target],
            device=self.shift.device,
            dim_range=(4, 4),
            n_channels=3,
            # NOTE(review): (0., -1.) presumably disables the value-range
            # check (lower bound above upper bound) for inputs that are
            # already normalized — confirm against _assert_type.
            value_range=(0., 1.) if self.scaling else (0., -1.),
        )

        # ImageNet scaling
        if self.scaling:
            input = (input - self.shift) / self.scale
            target = (target - self.shift) / self.scale

        # LPIPS
        residuals = []

        for lin, fx, fy in zip(self.lins, self.net(input), self.net(target)):
            # Unit-normalize along dim 1 (presumably the channel axis)
            fx = fx / (torch.linalg.norm(fx, dim=1, keepdim=True) + eps)
            fy = fy / (torch.linalg.norm(fy, dim=1, keepdim=True) + eps)

            mse = ((fx - fy) ** 2).mean(dim=(-1, -2), keepdim=True)
            residuals.append(lin(mse).flatten())

        l = torch.stack(residuals, dim=-1).sum(dim=-1)

        return _reduce(l, self.reduction)
예제 #4
0
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        r"""Validate both images, convert them, and return the reduced MS-GMSD.

        Args:
            input: Input image batch. Checked with ``_assert_type`` as 4-D
                with 3 channels, on the device of ``self.kernel`` and with
                values in ``[0, self.value_range]``.
            target: Target image batch with the same requirements.

        Returns:
            The MS-GMSD values reduced with ``self.reduction``.
        """

        kernel = self.kernel

        _assert_type(
            [input, target],
            device=kernel.device,
            dim_range=(4, 4),
            n_channels=3,
            value_range=(0., self.value_range),
        )

        # Convert both RGB images to Y (input first, then target)
        x, y = self.convert(input), self.convert(target)

        deviation = ms_gmsd(
            x,
            y,
            kernel=kernel,
            weights=self.weights,
            **self.kwargs,
        )
        return _reduce(deviation, self.reduction)
예제 #5
0
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        r"""Return the reduced GMSD between two half-resolution images.

        Both images are validated, average-pooled with a 2x2 window
        (``ceil_mode=True``), converted with ``self.convert`` and compared
        with :func:`gmsd`.

        Args:
            input: Input image batch. Checked with ``_assert_type`` as 4-D
                with 3 channels, on the device of ``self.kernel`` and with
                values in ``[0, self.value_range]``.
            target: Target image batch with the same requirements.

        Returns:
            The GMSD values reduced with ``self.reduction``.
        """

        _assert_type(
            [input, target],
            device=self.kernel.device,
            dim_range=(4, 4),
            n_channels=3,
            value_range=(0., self.value_range),
        )

        # Halve the resolution of both images first, then convert RGB to Y
        pooled = [F.avg_pool2d(img, 2, ceil_mode=True) for img in (input, target)]
        x, y = [self.convert(img) for img in pooled]

        return _reduce(
            gmsd(x, y, kernel=self.kernel, **self.kwargs),
            self.reduction,
        )
예제 #6
0
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""Compute the total variation of ``input`` and reduce it.

        Args:
            input: Tensor checked with ``_assert_type`` against its own
                device and ``dim_range=(3, -1)`` (presumably "at least 3
                dims, no upper bound" — confirm in ``_assert_type``).

        Returns:
            ``tv(input, **self.kwargs)`` reduced with ``self.reduction``.
        """

        _assert_type([input], device=input.device, dim_range=(3, -1))

        variation = tv(input, **self.kwargs)
        return _reduce(variation, self.reduction)
예제 #7
0
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the PSNR between ``input`` and ``target`` and reduce it.

        Args:
            input: Input tensor. Checked with ``_assert_type`` against its own
                device, ``dim_range=(1, -1)`` and values in
                ``[0, self.value_range]``.
            target: Target tensor with the same requirements.

        Returns:
            ``psnr(input, target, **self.kwargs)`` reduced with
            ``self.reduction``.
        """

        bounds = (0., self.value_range)

        _assert_type(
            [input, target],
            device=input.device,
            dim_range=(1, -1),
            value_range=bounds,
        )

        return _reduce(psnr(input, target, **self.kwargs), self.reduction)