def test_get_group_allocation_block_diagonal():
    """Two disjoint 2x2 blocks of ones must split cleanly into G=2 groups.

    Covers the block-diagonal layout and its anti-diagonal mirror. Group ids
    in the expectations are 1-based, and in both layouts the column groups
    (``gac``) are expected in column order: columns 0-1 -> 1, columns 2-3 -> 2.
    Only the row groups (``gaf``) differ between the two layouts.
    """

    def check(mask, top_rows_group, bottom_rows_group):
        # Run the allocator and verify the per-row / per-column group ids.
        row_groups, col_groups = get_group_allocation(mask, G=2)
        assert row_groups is not None
        assert col_groups is not None
        assert (row_groups[:2] == top_rows_group).all()
        assert (row_groups[2:] == bottom_rows_group).all()
        assert (col_groups[:2] == 1).all()
        assert (col_groups[2:] == 2).all()

    # Block-diagonal: only the top-left and bottom-right 2x2 blocks stay set.
    diagonal = torch.ones((4, 4), dtype=torch.bool)
    diagonal[2:, :2] = 0
    diagonal[:2, 2:] = 0
    check(diagonal, top_rows_group=1, bottom_rows_group=2)

    # Anti-diagonal: only the top-right and bottom-left 2x2 blocks stay set,
    # so rows 0-1 attach to column group 2 and rows 2-3 to column group 1.
    anti_diagonal = torch.ones((4, 4), dtype=torch.bool)
    anti_diagonal[:2, :2] = 0
    anti_diagonal[2:, 2:] = 0
    check(anti_diagonal, top_rows_group=2, bottom_rows_group=1)
def test_get_group_allocation_all_ones():
    """A fully dense mask forms a single group and cannot be split into two."""
    dense = torch.ones((2, 2), dtype=torch.bool)

    # G=1: every row and every column is placed in group 1.
    row_groups, col_groups = get_group_allocation(dense, G=1)
    assert row_groups is not None
    assert col_groups is not None
    assert (row_groups == 1).all()
    assert (col_groups == 1).all()

    # G=2: everything is connected, so no valid split exists and the
    # allocator is expected to report that with (None, None).
    row_groups, col_groups = get_group_allocation(dense, G=2)
    assert row_groups is None
    assert col_groups is None
def test_is_gsp_satisfied(self):
    """
    Check that is_gsp_satisfied reports whether a mask admits a G-group split,
    and that get_group_allocation agrees when it does not.
    """
    mask_conv = MaskConv2d(32, 32, 3)

    # Freshly constructed layer: one group is fine, two groups are not.
    self.assertTrue(is_gsp_satisfied(mask_conv, 1))
    self.assertFalse(is_gsp_satisfied(mask_conv, 2))

    # Zero the two diagonal 16x16 blocks; rows 0-15 now only touch columns
    # 16-31 and rows 16-31 only touch columns 0-15, a clean 2-group split.
    mask_conv.mask.data[:16, :16] = 0
    mask_conv.mask.data[16:, 16:] = 0
    self.assertTrue(is_gsp_satisfied(mask_conv, 2))

    # Move row 15 from column 16 onto column 15. Row 15 then shares column
    # 15 with rows 16-31 while still touching columns 17-31 with rows 0-14,
    # which bridges the two groups. The (16, 15) and (16, 16) writes restate
    # values those entries already hold after the zeroing above (assuming the
    # fresh mask is all ones, as the G=1/G=2 checks suggest).
    for index, value in (((15, 15), 1), ((15, 16), 0), ((16, 15), 1), ((16, 16), 0)):
        mask_conv.mask.data[index] = value

    self.assertFalse(is_gsp_satisfied(mask_conv, 2))
    allocation = get_group_allocation(mask_conv.mask, 2)
    self.assertIsNone(allocation[0])
    self.assertIsNone(allocation[1])
def test_get_group_allocation(self):
    """
    Check GSP group allocation on a non-square MaskConv2d(16, 32, 3) mask,
    whose row allocation (gaf) has 32 entries and column allocation (gac) 16.
    """
    mask_conv = MaskConv2d(16, 32, 3)

    # The mask starts dense, so no two-group split is possible yet.
    row_groups, col_groups = get_group_allocation(mask_conv.mask, 2)
    self.assertIsNone(row_groups)
    self.assertIsNone(col_groups)

    # After this, rows 16-31 only touch columns 0-7 and rows 0-15 only touch
    # columns 8-15.
    mask_conv.mask.data[16:, 8:] = 0
    mask_conv.mask.data[:16, :8] = 0

    row_groups, col_groups = get_group_allocation(mask_conv.mask, 2)
    self.assertIsNotNone(col_groups)
    self.assertIsNotNone(row_groups)

    # Columns 0-7 form group 1, so the rows wired to them (16-31) join group 1
    # and rows 0-15 join group 2.
    expectations = [
        (row_groups[16:], np.ones(16)),
        (row_groups[:16], np.ones(16) * 2),
        (col_groups[:8], np.ones(8)),
        (col_groups[8:], np.ones(8) * 2),
    ]
    for actual, expected in expectations:
        self.assertTrue(np.allclose(actual, expected))
def test_get_group_allocation_scattered():
    """Interleaved (checkerboard) connectivity still splits into G=2 groups.

    Rows 0 and 2 only connect to columns 1 and 3; rows 1 and 3 only connect
    to columns 0 and 2. The groups are therefore not contiguous index ranges.
    Column group ids follow column order (column 0 -> group 1), and each row
    takes the id of the columns it connects to.
    """
    mask = torch.tensor(
        [
            [0, 1, 0, 1],
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [1, 0, 1, 0],
        ],
        dtype=torch.bool,
    )
    gaf, gac = get_group_allocation(mask, G=2)
    assert gaf is not None
    assert gac is not None

    # Normalize to tensors before comparing. PyTorch does not accept a plain
    # list as a tensor operand for `==`, so comparing a tensor result against a
    # list literal would not be elementwise. torch.as_tensor handles tensors,
    # NumPy arrays and lists alike.
    gaf_t = torch.as_tensor(gaf)
    gac_t = torch.as_tensor(gac)
    # Pin the shape so broadcasting cannot hide a wrongly-shaped result.
    assert gaf_t.shape == (4,)
    assert gac_t.shape == (4,)
    assert (gaf_t == torch.tensor([2, 1, 2, 1])).all()
    assert (gac_t == torch.tensor([1, 2, 1, 2])).all()