def test_save_load_state_dict(self):
        """Check that a binarized model's ``state_dict`` round-trips.

        A binarized copy of ``self.net`` is run on a clone of ``self.input2``
        and its ``state_dict`` is captured. A second copy of ``self.net`` has
        ``self.weight_reset`` applied to it, is binarized with the same
        config, and is loaded with the captured state. Both forward passes
        must produce bit-identical outputs (``torch.equal``).
        """
        inputs = self.input2.clone()

        # Reference pass: binarize a private copy and snapshot its state.
        reference_model = prepare_binary_model(copy.deepcopy(self.net),
                                               bconfig=self.random_bconfig)
        reference_out = reference_model(inputs)
        saved_state = reference_model.state_dict()

        # Second copy: weight_reset runs before conversion (presumably
        # re-initialising the parameters -- confirm against the fixture), so
        # matching outputs can only come from the loaded state.
        restored_model = copy.deepcopy(self.net)
        restored_model.apply(self.weight_reset)
        restored_model = prepare_binary_model(restored_model,
                                              bconfig=self.random_bconfig)
        restored_model.load_state_dict(saved_state)
        restored_out = restored_model(inputs)

        self.assertTrue(torch.equal(reference_out, restored_out))
    def test_linear_layer(self):
        """Compare a binarized 3x3 bias-free linear layer to reference values.

        The layer weights are copied from ``self.weights`` and the input is
        the ``[:, :, 0, 0]`` slice of ``self.data`` viewed as one row of three
        features. The output must match the hard-coded expectation within an
        absolute tolerance of ``1e-4``.
        """
        fp_layer = nn.Linear(3, 3, bias=False)
        fp_layer.weight.data.copy_(self.weights.view(3, 3))
        sample = self.data[:, :, 0, 0].view(1, 3)
        binary_layer = prepare_binary_model(fp_layer,
                                            bconfig=self.test_bconfig)

        actual = binary_layer(sample)

        reference = torch.tensor([[0.0337, -0.0473, -0.1099]])
        self.assertTrue(torch.allclose(reference, actual, atol=1e-4))
    def test_conv2d_layer(self):
        """Compare a binarized 1x1 bias-free convolution to reference values.

        The 3->3 channel convolution takes its weights from ``self.weights``
        and is applied to ``self.data``. The output must match the hard-coded
        ``(1, 3, 2, 2)`` expectation within an absolute tolerance of ``1e-4``.
        """
        conv = nn.Conv2d(3, 3, 1, bias=False)
        conv.weight.data.copy_(self.weights.view(3, 3, 1, 1))
        binary_conv = prepare_binary_model(conv, bconfig=self.test_bconfig)

        actual = binary_conv(self.data)

        # The reference is one value per output channel multiplied by a
        # common 2x2 sign pattern. Multiplying by +/-1 is exact in floating
        # point, so this reproduces the literal reference tensor.
        channel_values = torch.tensor([0.0337, -0.0473, -0.1099])
        sign_pattern = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
        reference = channel_values.view(1, 3, 1, 1) * sign_pattern
        self.assertTrue(torch.allclose(reference, actual, atol=1e-4))
    def test_skip_binarization(self):
        """Check that a per-layer config override is applied on conversion.

        The layer named ``'8'`` is given a ``BConfig`` made entirely of
        ``nn.Identity`` stages, while the rest of the network is prepared
        with ``self.random_bconfig``. After conversion the model must contain
        exactly two ``Conv2d`` modules and exactly one ``Linear`` module whose
        ``activation_pre_process`` is an ``nn.Identity``.
        """
        # Deep copy, as test_save_load_state_dict does: a shallow copy of an
        # nn.Module shares its submodule container with the original, so
        # replacing children during conversion could alter self.net itself.
        model = copy.deepcopy(self.net)

        fp32_config = BConfig(activation_pre_process=nn.Identity,
                              activation_post_process=nn.Identity,
                              weight_pre_process=nn.Identity)
        model = prepare_binary_model(
            model,
            bconfig=self.random_bconfig,
            custom_config_layers_name={'8': fp32_config})

        conv_count = 0
        identity_linear_count = 0
        for module in model.modules():
            if isinstance(module, Conv2d):
                conv_count += 1
            elif (isinstance(module, Linear)
                  and isinstance(module.activation_pre_process, nn.Identity)):
                # Only linear layers that received the identity override count.
                identity_linear_count += 1

        self.assertEqual(conv_count, 2)
        self.assertEqual(identity_linear_count, 1)
    def test_many_layers(self):
        """Check that preparing ``self.linear_layer`` yields a ``Linear``.

        NOTE(review): despite the method name, the only fixture used here is
        ``self.linear_layer`` -- presumably a single layer; confirm whether a
        rename to mirror ``test_single_conv2d_layer`` is wanted.
        """
        layer_copy = copy.copy(self.linear_layer)
        converted = prepare_binary_model(layer_copy,
                                         bconfig=self.random_bconfig)

        # Exact type comparison: a subclass of Linear would not pass.
        converted_type = type(converted)
        self.assertEqual(converted_type, Linear)
    def test_single_conv2d_layer(self):
        """Check that preparing ``self.conv_layer`` yields a ``Conv2d``."""
        layer_copy = copy.copy(self.conv_layer)
        converted = prepare_binary_model(layer_copy,
                                         bconfig=self.random_bconfig)

        # Exact type comparison: a subclass of Conv2d would not pass.
        converted_type = type(converted)
        self.assertEqual(converted_type, Conv2d)
Beispiel #7
0
                                         shuffle=False,
                                         num_workers=2)

# Model
print('==> Building model..')
net = resnet18()

# Binarize
print('==> Preparing the model for binarization')
# Global config applied to every layer that has no explicit override below.
bconfig = BConfig(activation_pre_process=BasicInputBinarizer,
                  activation_post_process=Identity,
                  weight_pre_process=XNORWeightBinarizer)
# first and last layer will be kept FP32
layer_overrides = {
    'conv1': BConfig(),
    'fc': BConfig(),
}
model = prepare_binary_model(net,
                             bconfig,
                             custom_config_layers_name=layer_overrides)
print(model)

# Continue with the module returned by prepare_binary_model rather than the
# original `net`: if the conversion is not performed in place, `net` would
# still be the unbinarized network. When it is in place both names refer to
# the same object, so this is equivalent to the previous behaviour.
net = model.to(device)
if 'cuda' in device:
    # Wrap for multi-GPU data parallelism and enable the cuDNN autotuner.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # Guard against a missing checkpoint directory before trying to read it.
    # NOTE(review): assert statements are removed under `python -O`, so this
    # check disappears in optimized runs.
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # NOTE(review): torch.load unpickles the file, so only resume from
    # checkpoints you trust. No map_location is passed; presumably a
    # checkpoint saved on GPU will not load on a CPU-only machine -- confirm.
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    # Restore the weights stored under the 'net' key. If `net` was wrapped in
    # DataParallel above, the saved keys are presumably expected to carry the
    # matching 'module.' prefix -- verify against the code that saves them.
    net.load_state_dict(checkpoint['net'])