def test_batch_gram_matrix_normalize2(self):
    """Check that the normalized Gram matrix is (approximately) independent of image size.

    For each tensor factory, a 128x128 and a 256x256 input are drawn in that order
    and their normalized Gram matrices are compared with ``atol=2e-2``.
    """
    torch.manual_seed(0)
    for make_tensor in (torch.ones, torch.rand, torch.randn):
        # The small tensor is created before the large one so the RNG stream is consumed
        # in a fixed order.
        small_gram = pystiche.batch_gram_matrix(
            make_tensor((1, 3, 128, 128)), normalize=True
        )
        large_gram = pystiche.batch_gram_matrix(
            make_tensor((1, 3, 256, 256)), normalize=True
        )
        self.assertTensorAlmostEqual(small_gram, large_gram, atol=2e-2)
def test_GramOperator_call(self):
    """Check GramOperator's output against a manually computed reference.

    The reference is the MSE between the normalized Gram matrices of the encoded
    input image and the encoded target image.
    """
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 128, 128)
    input_image = torch.rand(1, 3, 128, 128)
    conv = nn.Conv2d(3, 3, 1)
    encoder = SequentialEncoder((conv,))

    gram_op = ops.GramOperator(encoder)
    gram_op.set_target_image(target_image)
    actual = gram_op(input_image)

    input_gram = pystiche.batch_gram_matrix(encoder(input_image), normalize=True)
    target_gram = pystiche.batch_gram_matrix(encoder(target_image), normalize=True)
    desired = mse_loss(input_gram, target_gram)

    self.assertTensorAlmostEqual(actual, desired)
def test_batch_gram_matrix_normalize1(self):
    """Check that an all-ones input yields an all-ones normalized Gram matrix."""
    channels = 3
    ones_input = torch.ones((1, channels, 128, 128))
    gram = pystiche.batch_gram_matrix(ones_input, normalize=True)

    expected = torch.ones((channels**2,))
    self.assertTensorAlmostEqual(gram.flatten(), expected)
def test_batch_gram_matrix(self):
    """Check the unnormalized Gram matrix of a single-channel all-ones input.

    For 1, 2 and 3 spatial dimensions of extent 100, the single Gram entry
    is expected to equal the number of spatial elements, ``100 ** dim``.
    """
    extent = 100
    for num_spatial in range(1, 4):
        shape = (1, 1) + (extent,) * num_spatial
        gram = pystiche.batch_gram_matrix(torch.ones(shape))
        self.assertAlmostEqual(gram.item(), float(extent**num_spatial))
def test_batch_gram_matrix_size(self):
    """Check that the Gram matrix has shape (batch, channels, channels).

    Inputs with 1, 2 and 3 spatial dimensions of random extent are tested.
    Only the output shape is checked, so the tensor contents (left
    uninitialized by ``torch.empty``) do not matter.
    """
    batch_size = 1
    num_channels = 3
    torch.manual_seed(0)
    for dim in (1, 2, 3):
        with self.subTest(dim=dim):
            # low=1: ``torch.randint(256, ...)`` has an inclusive low of 0 and can
            # draw a zero extent, turning the input into an empty tensor.
            # high=64: caps the 3D case at a few MB instead of up to ~199 MB.
            spatial_size = torch.randint(1, 64, (dim,)).tolist()
            x = torch.empty((batch_size, num_channels, *spatial_size))
            y = pystiche.batch_gram_matrix(x)
            actual = y.size()
            desired = (batch_size, num_channels, num_channels)
            self.assertTupleEqual(actual, desired)
def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
    """Turn an encoding into its Gram-matrix representation.

    Delegates to ``pystiche.batch_gram_matrix``. The ``normalize`` flag is
    taken from ``self.normalize``.
    """
    normalize = self.normalize
    return pystiche.batch_gram_matrix(enc, normalize=normalize)