def forward(self, feature_map1: t.Tensor, feature_map2: t.Tensor) -> t.Tensor:
    """Compute the feature-affinity loss between two equally-shaped feature maps.

    Each input is average-pooled with kernel/stride ``self.subsample_factor`` and
    converted to a similarity tensor by ``FALoss._calculate_matrix_similarity``.
    The last two axes of each similarity tensor are flattened into one axis of
    length L. Every one of the L entries of the first map is then compared
    against every one of the L entries of the second, giving L * L pairs per
    leading index, and the pairs are scored with ``F.l1_loss``.

    Args:
        feature_map1: 4-D tensor, (B, C, H, W) per the check below.
        feature_map2: tensor with exactly the same shape as ``feature_map1``.

    Returns:
        ``F.l1_loss`` over the pairwise-expanded similarity tensors, reduced
        according to ``self.reduction``.

    Raises:
        AssertionError: if ``feature_map1`` is not 4-D or the two shapes differ.
    """
    # torch.Assert was renamed to torch._assert (the FX-traceable assert);
    # it still raises AssertionError, so the failure contract is unchanged.
    t._assert(
        len(feature_map1.shape) == 4,
        "BUG CHECK: Feature map inputs to FALoss.forward() must have 4 dimensions (B, C, H, W).",
    )
    t._assert(
        feature_map1.shape == feature_map2.shape,
        "BUG CHECK: Feature map inputs to FALoss.forward() should be of same size.",
    )

    # Subsample, then build the similarity tensors. F.avg_pool2d with only a
    # kernel size matches nn.AvgPool2d(kernel_size) (stride defaults to the
    # kernel size) without constructing a throwaway module on every call.
    S_feature_map1 = FALoss._calculate_matrix_similarity(
        F.avg_pool2d(feature_map1, self.subsample_factor)
    )
    S_feature_map2 = FALoss._calculate_matrix_similarity(
        F.avg_pool2d(feature_map2, self.subsample_factor)
    )

    # Flatten axes 2-3 into one axis of length L. repeat(1, 1, L) below only
    # accepts a tensor of at most 3 dims, so the similarity tensors are 4-D here.
    S_feature_map1 = t.flatten(S_feature_map1, start_dim=2, end_dim=3)
    S_feature_map2 = t.flatten(S_feature_map2, start_dim=2, end_dim=3)

    # Pairwise expansion. Index k = i * L + j holds S1[i] (repeat_interleave)
    # and S2[j] (tiled repeat), so the loss compares every element of one map
    # with every element of the other.
    # NOTE(review): this materializes L * L entries per leading index and is
    # not an element-wise comparison; confirm both are intended.
    S_feature_map1 = t.repeat_interleave(
        S_feature_map1, repeats=S_feature_map1.shape[-1], dim=2
    )
    S_feature_map2 = S_feature_map2.repeat(1, 1, S_feature_map2.shape[-1])

    return F.l1_loss(S_feature_map1, S_feature_map2, reduction=self.reduction)
def forward(self, x):
    """Return ``x`` unchanged after asserting that ``x.shape[1] > 4``.

    Raises:
        AssertionError: with ``message`` when the shape check fails.
    """
    # torch.Assert was renamed to torch._assert, the symbolically traceable
    # assert, so the check also survives FX tracing.
    # NOTE(review): `message` is not defined in this block; presumably it is
    # captured from an enclosing scope. Verify.
    torch._assert(x.shape[1] > 4, message)
    return x
def test_assert_true(self): # verify assertions work as expected torch.Assert(True, "foo") with self.assertRaisesRegex(AssertionError, "bar"): torch.Assert(False, "bar")