Ejemplo n.º 1
0
    def test_batch_gram_matrix_normalize2(self):
        """Normalized Gram matrices of a small and a large input should nearly match.

        For each tensor constructor (ones, uniform, normal), a 128x128 and a
        256x256 three-channel input are built and their normalized Gram
        matrices are compared with an absolute tolerance of 2e-2.
        """
        torch.manual_seed(0)
        small_shape = (1, 3, 128, 128)
        large_shape = (1, 3, 256, 256)

        for make_tensor in (torch.ones, torch.rand, torch.randn):
            # The small tensor is drawn before the large one, matching the RNG order.
            gram_small = pystiche.batch_gram_matrix(make_tensor(small_shape), normalize=True)
            gram_large = pystiche.batch_gram_matrix(make_tensor(large_shape), normalize=True)
            self.assertTensorAlmostEqual(gram_small, gram_large, atol=2e-2)
Ejemplo n.º 2
0
    def test_GramOperator_call(self):
        """Calling a GramOperator should equal the MSE between normalized Gram matrices.

        The expected value is the ``mse_loss`` of the normalized Gram matrices of
        the encoded input image and the encoded target image.
        """
        torch.manual_seed(0)
        # Draw order matters for reproducibility: target, then input, then conv init.
        target_image = torch.rand(1, 3, 128, 128)
        input_image = torch.rand(1, 3, 128, 128)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        def normalized_gram(image):
            return pystiche.batch_gram_matrix(encoder(image), normalize=True)

        operator = ops.GramOperator(encoder)
        operator.set_target_image(target_image)

        actual = operator(input_image)
        desired = mse_loss(normalized_gram(input_image), normalized_gram(target_image))
        self.assertTensorAlmostEqual(actual, desired)
Ejemplo n.º 3
0
    def test_batch_gram_matrix_normalize1(self):
        """An all-ones input should give a normalized Gram matrix whose entries are all one."""
        num_channels = 3

        gram = pystiche.batch_gram_matrix(
            torch.ones((1, num_channels, 128, 128)), normalize=True
        )
        # Compare as a flat vector of num_channels**2 entries.
        self.assertTensorAlmostEqual(gram.flatten(), torch.ones((num_channels**2,)))
Ejemplo n.º 4
0
    def test_batch_gram_matrix(self):
        """Unnormalized Gram value of a single-channel all-ones input equals its element count.

        This is checked for 1, 2 and 3 spatial dimensions, each of length 100.
        """
        size = 100

        for dim in range(1, 4):
            spatial = (size,) * dim
            gram = pystiche.batch_gram_matrix(torch.ones((1, 1) + spatial))
            self.assertAlmostEqual(gram.item(), float(size**dim))
Ejemplo n.º 5
0
    def test_batch_gram_matrix_size(self):
        """The Gram matrix of a (batch, channels, *spatial) input has size (batch, channels, channels).

        Spatial extents are drawn at random (seeded) for 1, 2 and 3 spatial dims;
        only the output size is checked, so the input may stay uninitialized.
        """
        batch_size = 1
        num_channels = 3
        expected_size = (batch_size, num_channels, num_channels)

        torch.manual_seed(0)
        for dim in range(1, 4):
            spatial_dims = torch.randint(256, (dim,)).tolist()
            x = torch.empty((batch_size, num_channels, *spatial_dims))
            self.assertTupleEqual(pystiche.batch_gram_matrix(x).size(), expected_size)
Ejemplo n.º 6
0
 def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
     """Represent an encoding by its batched Gram matrix.

     Whether the Gram matrix is normalized is controlled by ``self.normalize``.
     """
     gram = pystiche.batch_gram_matrix(enc, normalize=self.normalize)
     return gram