Exemplo n.º 1
0
    def test_prune_module(self):
        """Pruning a MaskConv2d must leave it with a mask that is not all ones."""
        masked = MaskConv2d(32, 32, 3)
        num_groups = 4

        prune_utils.prune_module(masked, num_groups)

        # If the mask still equalled all ones, prune_module would have had no
        # visible effect on the layer.
        all_ones = torch.ones(masked.mask.shape)
        self.assertFalse(torch.allclose(masked.mask, all_ones))
Exemplo n.º 2
0
    def test_mask_assign(self):
        """Overwrite the internal mask with zeros; the output must be all zeros."""
        layer = MaskConv2d(32, 32, 3, bias=False)
        # Presumably the mask is (out_channels, in_channels) and gates only the
        # weights, hence bias=False keeps the expected output exactly zero.
        layer.mask.data = torch.zeros((32, 32))

        sample = torch.randn((1, 32, 4, 4))
        output = layer.forward(sample)

        expected = torch.zeros(output.shape)
        self.assertTrue(torch.allclose(output, expected))
Exemplo n.º 3
0
def conv1x1(in_planes, out_planes, stride=1, groups=1, indices=None,
            mask=False):
    """Build a 1x1 convolution, either grouped or masked.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (positionally third here; pass it by
            keyword to avoid mixing it up with ``groups``).
        groups: forwarded to ``GroupConv2d`` as ``num_groups``.
        indices: resolved through ``_get_indices`` for ``GroupConv2d``.
        mask: when truthy, build a bias-free ``MaskConv2d`` instead.

    Returns:
        A ``MaskConv2d`` if ``mask`` is truthy, otherwise a ``GroupConv2d``.

    Note:
        On the ``mask`` path, ``groups`` and ``indices`` are not forwarded.
    """
    if mask:
        return MaskConv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False)

    resolved_indices = _get_indices(indices, out_planes, num_groups=groups)
    return GroupConv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                       num_groups=groups, indices=resolved_indices)
Exemplo n.º 4
0
    def test_ctor(self):
        """A freshly constructed MaskConv2d must match nn.Conv2d with equal weights."""
        # Right after construction the mask is all ones, so masking is a no-op.
        masked = MaskConv2d(32, 32, 3, bias=False)
        reference = nn.Conv2d(32, 32, 3, bias=False)

        # Give both layers the very same weight tensor.
        shared_weight = torch.randn((32, 32, 3, 3))
        for layer in (reference, masked):
            layer.weight.data = shared_weight

        sample = torch.randn((1, 32, 4, 4))
        self.assertTrue(
            torch.allclose(masked.forward(sample), reference.forward(sample))
        )
Exemplo n.º 5
0
def conv3x3(in_planes, out_planes, groups=1, stride=1, indices=None, mask=False):
    """Build a 3x3 convolution with padding 1, either grouped or masked.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        groups: forwarded to ``GroupConv2d`` as ``num_groups`` (positionally
            third here; pass it by keyword to avoid mixing it up with ``stride``).
        stride: convolution stride.
        indices: resolved through ``_get_indices`` for ``GroupConv2d``.
        mask: when truthy, build a bias-free ``MaskConv2d`` instead.

    Returns:
        A ``MaskConv2d`` if ``mask`` is truthy, otherwise a ``GroupConv2d``.

    Note:
        On the ``mask`` path, ``groups`` and ``indices`` are not forwarded.
    """
    # Geometry shared by both variants.
    shared = dict(kernel_size=3, stride=stride, padding=1)

    if mask:
        return MaskConv2d(in_planes, out_planes, bias=False, **shared)

    return GroupConv2d(
        in_planes,
        out_planes,
        num_groups=groups,
        indices=_get_indices(indices, out_planes, num_groups=groups),
        **shared,
    )
Exemplo n.º 6
0
 def test_params(self):
     """Expect 2 parameters with the default bias setting, 3 with bias=True."""
     # Presumably weight + mask, plus a bias tensor when bias=True.
     for extra_kwargs, expected_count in (({}, 2), ({"bias": True}, 3)):
         layer = MaskConv2d(32, 32, 3, **extra_kwargs)
         self.assertEqual(len(list(layer.parameters())), expected_count)
Exemplo n.º 7
0
    def __init__(self):
        """Create a MaskConv2d layer and a Linear layer with 10 outputs."""
        super().__init__()

        # suppose the input image has shape 32 x 4 x 4
        # NOTE(review): a 3x3 kernel with (presumably) MaskConv2d's default
        # stride 1 / padding 0 turns 32 x 4 x 4 into 32 x 2 x 2, which is what
        # the Linear input size 32 * 2 * 2 below assumes — confirm the defaults.
        self.conv = MaskConv2d(32, 32, 3)
        # The flattened conv output feeds a Linear layer with 10 output features.
        self.fc = nn.Linear(32 * 2 * 2, 10)