def forward(self, feature_map1: t.Tensor,
                feature_map2: t.Tensor) -> t.Tensor:
        """Compute the feature-affinity L1 loss between two feature maps.

        Each input is average-pooled with kernel/stride ``self.subsample_factor``.
        It is then turned into a similarity tensor by
        ``FALoss._calculate_matrix_similarity``. The two similarity tensors are
        compared with ``F.l1_loss`` using ``reduction=self.reduction``.

        Args:
            feature_map1: 4-D tensor laid out as (B, C, H, W).
            feature_map2: tensor with exactly the same shape as ``feature_map1``.

        Returns:
            The result of ``F.l1_loss`` over the paired similarity entries,
            reduced according to ``self.reduction``.

        Raises:
            AssertionError: if ``feature_map1`` is not 4-D or the two shapes
                differ. Raised by ``t._assert``, which wraps Python's ``assert``.
        """
        # torch.Assert was renamed to torch._assert (the symbolically traceable
        # assert); the old name is not available in current PyTorch releases.
        t._assert(
            feature_map1.dim() == 4,
            "BUG CHECK: Feature map inputs to FALoss.forward() must have 4 dimensions (B, C, H, W)."
        )
        t._assert(
            feature_map1.shape == feature_map2.shape,
            "BUG CHECK: Feature map inputs to FALoss.forward() should be of same size."
        )

        # Subsample, then compute the similarity tensor. The functional op avoids
        # constructing a new AvgPool2d module on every call; its stride defaults
        # to the kernel size, exactly like AvgPool2d.
        S_feature_map1 = FALoss._calculate_matrix_similarity(
            F.avg_pool2d(feature_map1, self.subsample_factor))
        S_feature_map2 = FALoss._calculate_matrix_similarity(
            F.avg_pool2d(feature_map2, self.subsample_factor))

        # Flatten the last two dims into one axis of length K, then expand both
        # tensors to length K*K along that axis. Position i*K + j then holds
        # entry i of the first (repeat_interleave: a0,a0,...,a1,a1,...) and
        # entry j of the second (repeat: b0,b1,...,b0,b1,...).
        # The similarity tensors must be exactly 4-D here:
        #   - flatten(2, 3) needs at least 4 dims;
        #   - repeat(1, 1, K) needs the flattened result to have at most 3.
        # This materializes K*K values per leading index, so memory grows
        # quadratically in K.
        # NOTE(review): this compares every entry of one similarity tensor
        # with every entry of the other, not only corresponding entries;
        # confirm that is the intended loss.
        S_feature_map1 = t.flatten(S_feature_map1, start_dim=2, end_dim=3)
        S_feature_map1 = t.repeat_interleave(S_feature_map1,
                                             repeats=S_feature_map1.shape[-1],
                                             dim=2)
        S_feature_map2 = t.flatten(S_feature_map2, start_dim=2, end_dim=3)
        S_feature_map2 = S_feature_map2.repeat(1, 1, S_feature_map2.shape[-1])

        return F.l1_loss(S_feature_map1,
                         S_feature_map2,
                         reduction=self.reduction)
Example #2
0
 def forward(self, x):
     """Assert that dim 1 of ``x`` has size greater than 4, then return ``x`` unchanged.

     ``message`` is not defined in this snippet; presumably it is captured
     from an enclosing scope (TODO confirm).
     """
     # torch.Assert was renamed to torch._assert, the symbolically traceable
     # assert; the old name is not available in current PyTorch releases.
     torch._assert(x.shape[1] > 4, message)
     return x
Example #3
0
 def test_assert_true(self):
     # verify assertions work as expected
     torch.Assert(True, "foo")
     with self.assertRaisesRegex(AssertionError, "bar"):
         torch.Assert(False, "bar")