def get_expected_qbits(model, qbits, expected_overrides):
    expected_qbits = {}
    post_prepare_changes = {}
    prefix = 'module.' if isinstance(model, torch.nn.DataParallel) else ''
    for orig_name, orig_module in model.named_modules():
        bits_a, bits_w, bits_b = expected_overrides.get(orig_name.replace(prefix, '', 1), qbits)
        if not params_quantizable(orig_module):
            bits_w = bits_b = None
        expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)

        # We're testing replacement of module with container
        if isinstance(orig_module, (nn.Conv2d, nn.Linear)):
            post_prepare_changes[orig_name] = QBits(bits_a, None, None)
            post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name]

    return expected_qbits, post_prepare_changes
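
For reference, the helpers this snippet relies on look roughly like the following; this is a minimal sketch assuming the shape of the QBits record and the params_quantizable check used throughout these examples (the real definitions live in the module under test):

from collections import namedtuple
import torch.nn as nn

# Assumed shape of the QBits record: per-module bit-widths for activations, weights and bias
QBits = namedtuple('QBits', ['acts', 'wts', 'bias'])

def params_quantizable(module):
    # Only layers that actually carry weight parameters count as parameter-quantizable here
    return isinstance(module, (nn.Conv2d, nn.Linear))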
Example #2
def get_expected_qbits(model, qbits, expected_overrides):
    expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}

    expected_qbits = OrderedDict()
    post_prepare_qbits_changes = OrderedDict()
    post_prepare_expected_types = OrderedDict()
    prefix = 'module.' if isinstance(model, torch.nn.DataParallel) else ''
    for orig_name, orig_module in model.named_modules():
        orig_module_type = type(orig_module)
        bits_a, bits_w, bits_b = expected_overrides.get(orig_name.replace(prefix, '', 1), qbits)
        if not params_quantizable(orig_module):
            bits_w = bits_b = None
        expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)
        if expected_qbits[orig_name] == QBits(None, None, None):
            post_prepare_expected_types[orig_name] = orig_module_type
        else:
            post_prepare_expected_types[orig_name] = expected_type_replacements.get(orig_module_type, orig_module_type)
            # We're testing replacement of module with container
            if post_prepare_expected_types[orig_name] == DummyWrapperLayer:
                post_prepare_qbits_changes[orig_name] = QBits(bits_a, None, None)
                post_prepare_qbits_changes[orig_name + '.inner'] = expected_qbits[orig_name]
                post_prepare_expected_types[orig_name + '.inner'] = orig_module_type

    return expected_qbits, post_prepare_qbits_changes, post_prepare_expected_types
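
DummyQuantLayer and DummyWrapperLayer are test stand-ins for real quantized layers. Based only on how the snippets use them (a qbits attribute, and an inner sub-module on the wrapper), they could be sketched as follows; the constructor signatures and forward bodies are assumptions:

import torch.nn as nn

class DummyQuantLayer(nn.Module):
    # Straight replacement layer that just records the bit-widths it was given
    def __init__(self, qbits):
        super().__init__()
        self.qbits = qbits

    def forward(self, x):
        return x

class DummyWrapperLayer(nn.Module):
    # Container replacement: keeps the original module under '.inner', which is why the
    # expected-QBits maps above add entries for both orig_name and orig_name + '.inner'
    def __init__(self, module, qbits):
        super().__init__()
        self.qbits = qbits
        self.inner = module

    def forward(self, x):
        return self.inner(x)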
Example #3

def test_overrides_ordered_dict(model):
    pytest_raises_wrapper(TypeError, 'Expecting TypeError when overrides is not an OrderedDict',
                          DummyQuantizer, model, overrides={'testing': {'testing': '123'}})
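
pytest_raises_wrapper is a small helper around pytest.raises; its implementation is not shown in this snippet, but a plausible sketch is:

import pytest

def pytest_raises_wrapper(exc_type, msg, func, *args, **kwargs):
    # Fails the test with 'msg' if calling func(*args, **kwargs) does not raise exc_type
    with pytest.raises(exc_type):
        func(*args, **kwargs)
        pytest.fail(msg)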


acts_key = 'bits_activations'
wts_key = 'bits_weights'
bias_key = 'bits_bias'


@pytest.mark.parametrize(
    "qbits, overrides, explicit_expected_overrides",
    [
        (QBits(8, 4, 32), OrderedDict(), {}),
        (QBits(8, 4, 32),
         OrderedDict([('conv1', {acts_key: None, wts_key: None, bias_key: None}),
                      ('relu1', {acts_key: None, wts_key: None, bias_key: None})]),
         {'conv1': QBits(None, None, None), 'relu1': QBits(None, None, None)}),
        (QBits(8, 8, 32),
         OrderedDict([('sub.*conv1', {wts_key: 4}), ('sub.*conv2', {acts_key: 4, wts_key: 4})]),
         {'sub1.conv1': QBits(8, 4, 32), 'sub1.conv2': QBits(4, 4, 32), 'sub2.conv1': QBits(8, 4, 32), 'sub2.conv2': QBits(4, 4, 32)}),
        (QBits(4, 4, 32),
         OrderedDict([(r'sub1\..*1', {acts_key: 16, wts_key: 16}), (r'sub1\..*', {acts_key: 8, wts_key: 8})]),
         {'sub1.conv1': QBits(16, 16, 32), 'sub1.bn1': QBits(16, None, None),
          'sub1.relu1': QBits(16, None, None), 'sub1.pool1': QBits(16, None, None),
          'sub1.conv2': QBits(8, 8, 32), 'sub1.bn2': QBits(8, None, None),
          'sub1.relu2': QBits(8, None, None), 'sub1.pool2': QBits(8, None, None)}),
        (QBits(4, 4, 32),
         OrderedDict([(r'sub1\..*', {acts_key: 8, wts_key: 8}), (r'sub1\..*1', {acts_key: 16, wts_key: 16})]),
Example #4
def test_model_prep(model, optimizer, qbits, bits_overrides,
                    explicit_expected_overrides, train_with_fp_copy,
                    quantize_bias, parallel):
    if parallel:
        model = torch.nn.DataParallel(model)
    m_orig = deepcopy(model)

    # Build expected QBits
    expected_qbits, post_prepare_changes = get_expected_qbits(
        model, qbits, explicit_expected_overrides)

    # Initialize Quantizer
    q = DummyQuantizer(model,
                       optimizer=optimizer,
                       bits_activations=qbits.acts,
                       bits_weights=qbits.wts,
                       bits_overrides=deepcopy(bits_overrides),
                       train_with_fp_copy=train_with_fp_copy,
                       quantize_bias=quantize_bias)

    # Check number of bits for quantization were registered correctly
    assert q.module_qbits_map == expected_qbits

    q.prepare_model()
    expected_qbits.update(post_prepare_changes)

    for ptq in q.params_to_quantize:
        assert params_quantizable(ptq.module)
        assert expected_qbits[ptq.module_name].wts is not None

        # Check parameter names are as expected
        assert ptq.q_attr_name in ['weight', 'bias']

        # Check bias will be quantized only if flag is enabled
        if ptq.q_attr_name == 'bias':
            assert quantize_bias

        named_params = dict(ptq.module.named_parameters())
        if q.train_with_fp_copy:
            # Checking parameter replacement is as expected
            assert ptq.fp_attr_name == FP_BKP_PREFIX + ptq.q_attr_name
            assert ptq.fp_attr_name in named_params
            assert ptq.q_attr_name not in named_params
            # Making sure the following doesn't throw an exception,
            # so we know q_attr_name is still a buffer in the module
            getattr(ptq.module, ptq.q_attr_name)
        else:
            # Make sure we didn't screw anything up
            assert ptq.fp_attr_name == ptq.q_attr_name
            assert ptq.fp_attr_name in named_params

        # Check number of bits registered correctly
        # Bias number of bits is hard-coded to 32 for now...
        expected_n_bits = 32 if ptq.q_attr_name == 'bias' else expected_qbits[
            ptq.module_name].wts
        assert ptq.num_bits == expected_n_bits

    q_named_modules = dict(model.named_modules())
    orig_named_modules = dict(m_orig.named_modules())
    for orig_name, orig_module in orig_named_modules.items():
        # Check no module name from original model is missing
        assert orig_name in q_named_modules

        # Check module replacement is as expected
        q_module = q_named_modules[orig_name]
        expected_type = expected_type_replacements.get(type(orig_module))
        if expected_type is None or expected_qbits[orig_name] == QBits(
                None, None):
            assert type(orig_module) == type(q_module)
        else:
            assert type(q_module) == expected_type
            if expected_type == DummyWrapperLayer:
                assert expected_qbits[orig_name + '.inner'] == q_module.qbits
            else:
                assert expected_qbits[orig_name] == q_module.qbits
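
The train_with_fp_copy branch above asserts that each quantized parameter is renamed with FP_BKP_PREFIX while the original attribute name survives as a buffer. A minimal sketch of that replacement, assuming FP_BKP_PREFIX is 'float_' and using a hypothetical helper name:

import torch
import torch.nn as nn

FP_BKP_PREFIX = 'float_'  # assumed value of the prefix the test checks against

def backup_and_replace_param(module, param_name):
    # Hypothetical illustration: keep the full-precision tensor as 'float_<name>' (still an
    # nn.Parameter, so the optimizer keeps training it) and re-register the original name
    # as a buffer that will hold the quantized values.
    fp_tensor = getattr(module, param_name)
    delattr(module, param_name)
    module.register_parameter(FP_BKP_PREFIX + param_name, nn.Parameter(fp_tensor.data))
    module.register_buffer(param_name, torch.zeros_like(fp_tensor.data))

conv = nn.Conv2d(3, 8, 3)
backup_and_replace_param(conv, 'weight')
named_params = dict(conv.named_parameters())
assert 'float_weight' in named_params and 'weight' not in named_params
getattr(conv, 'weight')  # does not raise: 'weight' is now a buffer on the module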
Example #5
    # Check that the original model's parameters were left untouched
    assert all(
        torch.equal(q_param, orig_param) for q_param, orig_param in zip(
            model.parameters(), m_orig.parameters()))


def test_overrides_ordered_dict(model):
    # bits_overrides must be an OrderedDict (override order is significant),
    # so passing a plain dict is expected to raise TypeError
    with pytest.raises(TypeError):
        DummyQuantizer(model, bits_overrides={'testing': '123'})


@pytest.mark.parametrize(
    "qbits, bits_overrides, explicit_expected_overrides",
    [
        (QBits(8, 4), OrderedDict(), {}),
        (QBits(8, 4),
         OrderedDict([('conv1', {
             'acts': None,
             'wts': None
         }), ('relu1', {
             'acts': None,
             'wts': None
         })]), {
             'conv1': QBits(None, None),
             'relu1': QBits(None, None)
         }),
        (QBits(8, 8),
         OrderedDict([('sub.*conv1', {
             'wts': 4
         }), ('sub.*conv2', {