def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias, mode,
                              clip_acts, per_channel_wts, expected_output):
    """Check RangeLinearQuantParamLayerWrapper around a ``torch.nn.Linear``.

    A linear layer is built with the fixture weights and bias and wrapped with
    bit-width arguments 8 and 8 (presumably activations and parameters; TODO
    confirm against the wrapper's signature). The input is annotated with
    quantization metadata via ``attach_quant_metadata``.

    Calling the wrapper before ``eval()`` must raise ``RuntimeError``. After
    ``eval()``, its output must match ``expected_output``.

    NOTE(review): ``torch.testing.assert_allclose`` is deprecated in favour of
    ``torch.testing.assert_close``. Migrating changes the default dtype and NaN
    checks, so it is kept as-is here.
    """
    in_features = linear_input.shape[1]
    out_features = expected_output.shape[1]
    fc = torch.nn.Linear(in_features, out_features, bias=True)
    fc.weight.data = linear_weights
    fc.bias.data = linear_bias

    wrapper = RangeLinearQuantParamLayerWrapper(
        fc, 8, 8,
        mode=mode,
        clip_acts=clip_acts,
        per_channel_wts=per_channel_wts,
    )

    # No pre-collected stats are supplied for the input tensor.
    quant_input = attach_quant_metadata(
        linear_input, 8, mode,
        stats=None,
        clip_mode=clip_acts,
        per_channel=False,
        num_stds=None,
        scale_approx_mult_bits=None,
    )

    # The wrapper is expected to refuse to run until it is put in eval mode.
    with pytest.raises(RuntimeError):
        wrapper(quant_input)

    wrapper.eval()
    torch.testing.assert_allclose(wrapper(quant_input), expected_output)
def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts,
                            per_channel_wts, conv_stats, expected_output):
    """Check RangeLinearQuantParamLayerWrapper around a bias-free ``Conv2d``.

    The convolution takes its channel counts from the input and expected
    output, its kernel size from ``conv_weights.shape[-1]``, and uses
    ``padding=1``. ``conv_stats`` (possibly ``None``) is passed to the wrapper
    as ``activation_stats``. When present, ``conv_stats['inputs'][0]`` is also
    used as the stats for the input tensor's quantization metadata.

    Calling the wrapper before ``eval()`` must raise ``RuntimeError``. After
    ``eval()``, its output must match ``expected_output``.
    """
    conv = torch.nn.Conv2d(
        conv_input.shape[1],
        expected_output.shape[1],
        conv_weights.shape[-1],
        padding=1,
        bias=False,
    )
    conv.weight.data = conv_weights

    wrapper = RangeLinearQuantParamLayerWrapper(
        conv, 8, 8,
        mode=mode,
        clip_acts=clip_acts,
        per_channel_wts=per_channel_wts,
        activation_stats=conv_stats,
    )

    # Without stats, the input metadata is built with stats=None.
    if conv_stats is None:
        input_stats = None
    else:
        input_stats = conv_stats['inputs'][0]

    quant_input = attach_quant_metadata(
        conv_input, 8, mode,
        stats=input_stats,
        clip_mode=clip_acts,
        per_channel=False,
        num_stds=None,
        scale_approx_mult_bits=None,
    )

    # The wrapper is expected to refuse to run until it is put in eval mode.
    with pytest.raises(RuntimeError):
        wrapper(quant_input)

    wrapper.eval()
    torch.testing.assert_allclose(wrapper(quant_input), expected_output)
# Example #3
# 0
def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts, per_channel_wts, conv_stats, expected_output):
    """Check RangeLinearQuantParamLayerWrapper around a bias-free ``Conv2d``.

    In this variant the raw input tensor is fed directly, with no attached
    quantization metadata. ``conv_stats`` is forwarded to the wrapper as
    ``activation_stats``. The wrapper must raise ``RuntimeError`` before
    ``eval()`` and reproduce ``expected_output`` after it.
    """
    in_channels = conv_input.shape[1]
    out_channels = expected_output.shape[1]
    kernel_size = conv_weights.shape[-1]
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                           padding=1, bias=False)
    conv.weight.data = conv_weights

    wrapper = RangeLinearQuantParamLayerWrapper(
        conv, 8, 8,
        mode=mode,
        clip_acts=clip_acts,
        per_channel_wts=per_channel_wts,
        activation_stats=conv_stats,
    )

    # The wrapper is expected to refuse to run until it is put in eval mode.
    with pytest.raises(RuntimeError):
        wrapper(conv_input)

    wrapper.eval()
    torch.testing.assert_allclose(wrapper(conv_input), expected_output)
# Example #4
# 0
def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias,
                              mode, clip_acts, per_channel_wts, expected_output):
    """Check RangeLinearQuantParamLayerWrapper around a ``torch.nn.Linear``.

    In this variant the raw input tensor is fed directly, with no attached
    quantization metadata. The wrapper must raise ``RuntimeError`` before
    ``eval()`` and reproduce ``expected_output`` after it.
    """
    fc = torch.nn.Linear(linear_input.shape[1], expected_output.shape[1],
                         bias=True)
    fc.weight.data = linear_weights
    fc.bias.data = linear_bias

    wrapper = RangeLinearQuantParamLayerWrapper(
        fc, 8, 8,
        mode=mode,
        clip_acts=clip_acts,
        per_channel_wts=per_channel_wts,
    )

    # The wrapper is expected to refuse to run until it is put in eval mode.
    with pytest.raises(RuntimeError):
        wrapper(linear_input)

    wrapper.eval()
    torch.testing.assert_allclose(wrapper(linear_input), expected_output)
# Example #5
# 0
def _test_wts_only_quant(layer, x, per_channel, bias, num_bits):
    """Compare wrapper-based parameter quantization against manual fake-quant.

    ``layer`` is mutated: its weight, and its bias when ``bias`` is truthy, are
    overwritten with ``torch.rand_like`` values. Two deep copies are then made:

    * one wrapped in ``RangeLinearQuantParamLayerWrapper``, with ``None`` as the
      first bit-width argument and ``num_bits`` as the second, in
      ``ASYMMETRIC_UNSIGNED`` mode, then switched to eval mode;
    * one whose weight (and bias) tensors are passed through
      ``_fake_quant_tensor``. The return value is discarded, so that helper
      presumably works in place; TODO confirm.

    The wrapped module's parameters and both copies' outputs on ``x`` must be
    exactly equal (``torch.equal``).
    """
    layer.weight.data = torch.rand_like(layer.weight)
    if bias:
        layer.bias.data = torch.rand_like(layer.bias)
    quant_mode = LinearQuantMode.ASYMMETRIC_UNSIGNED

    wrapped = RangeLinearQuantParamLayerWrapper(
        deepcopy(layer), None, num_bits,
        mode=quant_mode, per_channel_wts=per_channel,
    )
    wrapped.eval()

    reference = deepcopy(layer)
    _fake_quant_tensor(reference.weight.data, num_bits, quant_mode, per_channel)
    assert torch.equal(wrapped.wrapped_module.weight, reference.weight)
    if bias:
        # Bias is fake-quantized per-tensor (per_channel=False) regardless of the
        # per_channel setting used for the weights.
        _fake_quant_tensor(reference.bias.data, num_bits, quant_mode, False)
        assert torch.equal(wrapped.wrapped_module.bias, reference.bias)

    # Arguments evaluate left to right: the wrapped layer runs first, then the reference.
    assert torch.equal(wrapped(x), reference(x))