コード例 #1
0
def test_get_group_allocation_block_diagonal():
    """Check two-group allocation on block-diagonal and anti-diagonal masks.

    Each 4x4 boolean mask keeps exactly two disjoint 2x2 blocks, so it can be
    split into two groups. The test pins which group label (1 or 2) the
    first and second halves of ``gaf`` and ``gac`` receive in each case.
    """

    def _expect_halves(mask, gaf_head, gaf_tail, gac_head, gac_tail):
        # Run the allocator for G=2 and check the label of each half.
        gaf, gac = get_group_allocation(mask, G=2)
        assert gaf is not None
        assert gac is not None
        assert (gaf[:2] == gaf_head).all()
        assert (gaf[2:] == gaf_tail).all()
        assert (gac[:2] == gac_head).all()
        assert (gac[2:] == gac_tail).all()

    # Block-diagonal: only the top-left and bottom-right 2x2 blocks stay set.
    mask = torch.ones((4, 4), dtype=torch.bool)
    mask[2:, :2] = 0
    mask[:2, 2:] = 0
    _expect_halves(mask, gaf_head=1, gaf_tail=2, gac_head=1, gac_tail=2)

    # Anti-diagonal: only the top-right and bottom-left 2x2 blocks stay set.
    mask = torch.ones((4, 4), dtype=torch.bool)
    mask[:2, :2] = 0
    mask[2:, 2:] = 0
    _expect_halves(mask, gaf_head=2, gaf_tail=1, gac_head=1, gac_tail=2)
コード例 #2
0
def test_get_group_allocation_all_ones():
    """A fully dense 2x2 mask admits a single group but not two."""
    dense = torch.ones((2, 2), dtype=torch.bool)

    # With G=1 every entry of both allocations is labelled group 1.
    alloc_f, alloc_c = get_group_allocation(dense, G=1)

    assert alloc_f is not None
    assert alloc_c is not None
    assert (alloc_f == 1).all()
    assert (alloc_c == 1).all()

    # With no zero entries there is nothing to separate two groups along,
    # so the allocator is expected to give up and return None for both.
    alloc_f, alloc_c = get_group_allocation(dense, G=2)
    assert alloc_f is None
    assert alloc_c is None
コード例 #3
0
    def test_is_gsp_satisfied(self):
        """Test whether GSP can be detected.

        The test expects a freshly built ``MaskConv2d(32, 32, 3)`` to satisfy
        GSP for one group but not for two. Zeroing the two diagonal 16x16
        blocks of its mask should make a two-group split possible. Rewiring
        a few entries around index 15/16 then links the two halves together,
        after which the split must be rejected both by ``is_gsp_satisfied``
        and by ``get_group_allocation`` (which returns ``None`` for both
        allocations).
        """
        mask_conv = MaskConv2d(32, 32, 3)

        self.assertTrue(is_gsp_satisfied(mask_conv, 1))
        self.assertFalse(is_gsp_satisfied(mask_conv, 2))

        # Keep only the off-diagonal blocks (rows :16 with cols 16:, rows 16:
        # with cols :16), giving two independent groups.
        mask_conv.mask.data[:16, :16] = 0
        mask_conv.mask.data[16:, 16:] = 0
        self.assertTrue(is_gsp_satisfied(mask_conv, 2))

        # Connect row 15 to column 15, which bridges the two halves
        # (assuming the mask started out all ones, so column 15 is still
        # connected to rows 16:). The [16, 16] write is a no-op after the
        # block zeroing above, and [16, 15] presumably already holds 1; both
        # are kept to spell out the intended pattern.
        mask_conv.mask.data[15, 15] = 1
        mask_conv.mask.data[15, 16] = 0
        mask_conv.mask.data[16, 15] = 1
        mask_conv.mask.data[16, 16] = 0
        self.assertFalse(is_gsp_satisfied(mask_conv, 2))

        # Unpack a single call, as the other allocation tests do, instead of
        # running the allocator twice to index [0] and [1].
        gaf, gac = get_group_allocation(mask_conv.mask, 2)
        self.assertIsNone(gaf)
        self.assertIsNone(gac)
コード例 #4
0
    def test_get_group_allocation(self):
        """Test GSP based group allocation.

        The slicing below treats the mask's first axis as 32 long and its
        second as 16 long (presumably out_channels x in_channels of
        ``MaskConv2d(16, 32, 3)`` -- confirm against MaskConv2d).
        """
        conv = MaskConv2d(16, 32, 3)

        # The mask starts fully populated, so no two-group split exists and
        # both allocations should come back as None.
        alloc_f, alloc_c = get_group_allocation(conv.mask, 2)
        self.assertIsNone(alloc_f)
        self.assertIsNone(alloc_c)

        # Detach rows :16 from columns :8 and rows 16: from columns 8:,
        # leaving rows 16: <-> cols :8 and rows :16 <-> cols 8: as groups.
        conv.mask.data[:16, :8] = 0
        conv.mask.data[16:, 8:] = 0

        alloc_f, alloc_c = get_group_allocation(conv.mask, 2)
        self.assertIsNotNone(alloc_c)
        self.assertIsNotNone(alloc_f)
        # The group containing columns :8 is labelled 1, the other 2.
        self.assertTrue(np.allclose(alloc_f[16:], np.ones(16)))
        self.assertTrue(np.allclose(alloc_f[:16], np.full(16, 2.0)))
        self.assertTrue(np.allclose(alloc_c[:8], np.ones(8)))
        self.assertTrue(np.allclose(alloc_c[8:], np.full(8, 2.0)))
コード例 #5
0
def test_get_group_allocation_scattered():
    """Allocate two groups for a checkerboard-patterned 4x4 mask.

    Even rows connect only to odd columns and odd rows only to even columns,
    so the mask still separates into two groups even though neither group
    forms a contiguous block.
    """
    checker = torch.tensor(
        [
            [0, 1, 0, 1],
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [1, 0, 1, 0],
        ],
        dtype=torch.bool,
    )

    gaf, gac = get_group_allocation(checker, G=2)
    assert gaf is not None
    assert gac is not None
    # Odd rows share group 1 with even columns; even rows share group 2
    # with odd columns.
    assert (gaf == [2, 1, 2, 1]).all()
    assert (gac == [1, 2, 1, 2]).all()