# Expected module type replacements performed by DummyQuantizer.prepare_model()
expected_type_replacements = {nn.Conv2d: DummyWrapperLayer,
                              nn.ReLU: DummyQuantLayer,
                              nn.Linear: DummyWrapperLayer}


def get_expected_qbits(model, qbits, expected_overrides):
    expected_qbits = OrderedDict()
    post_prepare_qbits_changes = OrderedDict()
    post_prepare_expected_types = OrderedDict()
    prefix = 'module.' if isinstance(model, torch.nn.DataParallel) else ''
    for orig_name, orig_module in model.named_modules():
        orig_module_type = type(orig_module)
        bits_a, bits_w, bits_b = expected_overrides.get(orig_name.replace(prefix, '', 1), qbits)
        if not params_quantizable(orig_module):
            bits_w = bits_b = None
        expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)

        if expected_qbits[orig_name] == QBits(None, None, None):
            # Nothing to quantize - module shouldn't be replaced
            post_prepare_expected_types[orig_name] = orig_module_type
        else:
            post_prepare_expected_types[orig_name] = expected_type_replacements.get(orig_module_type,
                                                                                    orig_module_type)

        # We're testing replacement of module with container
        if post_prepare_expected_types[orig_name] == DummyWrapperLayer:
            post_prepare_qbits_changes[orig_name] = QBits(bits_a, None, None)
            post_prepare_qbits_changes[orig_name + '.inner'] = expected_qbits[orig_name]
            post_prepare_expected_types[orig_name + '.inner'] = orig_module_type

    return expected_qbits, post_prepare_qbits_changes, post_prepare_expected_types
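

# A minimal sketch (an assumption, not the quantizer's actual code) of the
# override resolution DummyQuantizer is expected to perform: regex patterns
# are tried in insertion order and the first match wins. This is why
# overrides must be an OrderedDict (see test_overrides_ordered_dict below)
# and why swapping pattern order in the parametrized cases changes the
# expected QBits. For brevity, overrides here map a pattern directly to a
# QBits tuple; the helper name is hypothetical.
def _resolve_override_sketch(module_name, overrides, default_qbits):
    import re
    for pattern, pattern_qbits in overrides.items():
        # re.match anchors at the start of the name, so 'conv1' doesn't
        # match 'sub1.conv1'
        if re.match(pattern, module_name):
            return pattern_qbits
    return default_qbits

# For example, with the broader pattern first, 'sub1.conv1' resolves to the
# 8-bit override and the 16-bit one is never reached:
#   _resolve_override_sketch('sub1.conv1',
#                            OrderedDict([(r'sub1\..*', QBits(8, 8, 32)),
#                                         (r'sub1\..*1', QBits(16, 16, 32))]),
#                            QBits(4, 4, 32))  # -> QBits(8, 8, 32)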
def test_overrides_ordered_dict(model):
    pytest_raises_wrapper(TypeError, 'Expecting TypeError when overrides is not an OrderedDict',
                          DummyQuantizer, model, overrides={'testing': {'testing': '123'}})


acts_key = 'bits_activations'
wts_key = 'bits_weights'
bias_key = 'bits_bias'


@pytest.mark.parametrize(
    "qbits, overrides, explicit_expected_overrides",
    [
        (QBits(8, 4, 32), OrderedDict(), {}),
        (QBits(8, 4, 32),
         OrderedDict([('conv1', {acts_key: None, wts_key: None, bias_key: None}),
                      ('relu1', {acts_key: None, wts_key: None, bias_key: None})]),
         {'conv1': QBits(None, None, None), 'relu1': QBits(None, None, None)}),
        (QBits(8, 8, 32),
         OrderedDict([('sub.*conv1', {wts_key: 4}), ('sub.*conv2', {acts_key: 4, wts_key: 4})]),
         {'sub1.conv1': QBits(8, 4, 32), 'sub1.conv2': QBits(4, 4, 32),
          'sub2.conv1': QBits(8, 4, 32), 'sub2.conv2': QBits(4, 4, 32)}),
        (QBits(4, 4, 32),
         OrderedDict([(r'sub1\..*1', {acts_key: 16, wts_key: 16}), (r'sub1\..*', {acts_key: 8, wts_key: 8})]),
         {'sub1.conv1': QBits(16, 16, 32), 'sub1.bn1': QBits(16, None, None),
          'sub1.relu1': QBits(16, None, None), 'sub1.pool1': QBits(16, None, None),
          'sub1.conv2': QBits(8, 8, 32), 'sub1.bn2': QBits(8, None, None),
          'sub1.relu2': QBits(8, None, None), 'sub1.pool2': QBits(8, None, None)}),
        # Same patterns in reverse order - the broader r'sub1\..*' now matches
        # first, so the 16-bit override is never reached and all of 'sub1'
        # gets 8 bits
        (QBits(4, 4, 32),
         OrderedDict([(r'sub1\..*', {acts_key: 8, wts_key: 8}), (r'sub1\..*1', {acts_key: 16, wts_key: 16})]),
         {'sub1.conv1': QBits(8, 8, 32), 'sub1.bn1': QBits(8, None, None),
          'sub1.relu1': QBits(8, None, None), 'sub1.pool1': QBits(8, None, None),
          'sub1.conv2': QBits(8, 8, 32), 'sub1.bn2': QBits(8, None, None),
          'sub1.relu2': QBits(8, None, None), 'sub1.pool2': QBits(8, None, None)})
    ])
# The remaining test_model_prep arguments are assumed to be simple boolean
# parametrizations
@pytest.mark.parametrize('train_with_fp_copy', [False, True])
@pytest.mark.parametrize('quantize_bias', [False, True])
@pytest.mark.parametrize('parallel', [False, True])
def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overrides,
                    train_with_fp_copy, quantize_bias, parallel):
    if parallel:
        model = torch.nn.DataParallel(model)
    m_orig = deepcopy(model)

    # Build expected QBits
    expected_qbits, post_prepare_changes, post_prepare_types = get_expected_qbits(
        model, qbits, explicit_expected_overrides)

    # Initialize Quantizer
    q = DummyQuantizer(model, optimizer=optimizer,
                       bits_activations=qbits.acts, bits_weights=qbits.wts,
                       overrides=deepcopy(overrides),
                       train_with_fp_copy=train_with_fp_copy, quantize_bias=quantize_bias)

    # Check that the numbers of bits for quantization were registered correctly
    assert q.module_qbits_map == expected_qbits

    q.prepare_model()

    expected_qbits.update(post_prepare_changes)

    for ptq in q.params_to_quantize:
        assert params_quantizable(ptq.module)
        assert expected_qbits[ptq.module_name].wts is not None

        # Check parameter names are as expected
        assert ptq.q_attr_name in ['weight', 'bias']
        # Check bias will be quantized only if the flag is enabled
        if ptq.q_attr_name == 'bias':
            assert quantize_bias

        named_params = dict(ptq.module.named_parameters())
        if q.train_with_fp_copy:
            # Check parameter replacement is as expected
            assert ptq.fp_attr_name == FP_BKP_PREFIX + ptq.q_attr_name
            assert ptq.fp_attr_name in named_params
            assert ptq.q_attr_name not in named_params
            # Make sure the following doesn't throw an exception,
            # so we know q_attr_name is still a buffer in the module
            getattr(ptq.module, ptq.q_attr_name)
        else:
            # Make sure we didn't screw anything up
            assert ptq.fp_attr_name == ptq.q_attr_name
            assert ptq.fp_attr_name in named_params

        # Check the number of bits was registered correctly
        # Bias number of bits is hard-coded to 32 for now...
        expected_n_bits = 32 if ptq.q_attr_name == 'bias' else expected_qbits[ptq.module_name].wts
        assert ptq.num_bits == expected_n_bits

    q_named_modules = dict(model.named_modules())
    orig_named_modules = dict(m_orig.named_modules())
    for orig_name, orig_module in orig_named_modules.items():
        # Check no module name from the original model is missing
        assert orig_name in q_named_modules

        # Check module replacement is as expected
        q_module = q_named_modules[orig_name]
        expected_type = post_prepare_types[orig_name]
        assert type(q_module) == expected_type
        if expected_type != type(orig_module):
            # Module was replaced - check the QBits registered on the replacement
            if expected_type == DummyWrapperLayer:
                assert expected_qbits[orig_name + '.inner'] == q_module.qbits
            else:
                assert expected_qbits[orig_name] == q_module.qbits
    # Check that prepare_model didn't change any parameter values
    assert all(torch.equal(q_param, orig_param)
               for q_param, orig_param in zip(model.parameters(), m_orig.parameters()))
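

# A minimal sketch (an assumption, not DummyQuantizer's actual code) of the
# train_with_fp_copy mechanics that test_model_prep asserts on: the original
# parameter is re-registered under FP_BKP_PREFIX + name and the original name
# becomes a buffer holding the to-be-quantized copy. As a result,
# named_parameters() reports only the FP backup, while getattr() still works
# for the original attribute name. The helper name is hypothetical.
def _fp_copy_sketch(module, param_name):
    param = getattr(module, param_name)
    delattr(module, param_name)
    # Keep the full-precision values as the trainable parameter
    module.register_parameter(FP_BKP_PREFIX + param_name, param)
    # Expose the original name as a buffer, to be overwritten with the
    # quantized values at runtime
    module.register_buffer(param_name, param.data.clone())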