示例#1
0
def test_parameter_from_stats_update():
    """Check that loading float weights refreshes a stats-initialised scale.

    A binary-weight ``QuantLinear`` using ``parameter_from_stats`` weight
    scaling is loaded with the state dict of a plain ``nn.Linear``. The
    weight scale must equal the max absolute weight of the quantized layer
    both before and after the load, and after the load it must also equal
    the max absolute weight of the float layer.
    """
    # The flag is process-wide state: remember the caller's value so this
    # test does not leak its setting into tests that run afterwards.
    previous_ignore_missing_keys = config.IGNORE_MISSING_KEYS
    # Presumably needed because the float layer's state dict lacks the
    # quantizer entries of QuantLinear -- confirm against brevitas.config.
    config.IGNORE_MISSING_KEYS = True
    try:
        linear = nn.Linear(10, 5, bias=False)
        q_linear = QuantLinear(10,
                               5,
                               bias=False,
                               weight_quant_type='binary',
                               weight_scaling_impl_type='parameter_from_stats')
        l_max = linear.weight.abs().max()

        # Snapshot scale and weight statistics before the float weights land.
        old_scale = q_linear.quant_weight_scale()
        old_ql_max = q_linear.weight.abs().max()

        q_linear.load_state_dict(linear.state_dict())

        new_scale = q_linear.quant_weight_scale()
        new_ql_max = q_linear.weight.abs().max()

        assert old_scale == old_ql_max
        assert new_scale == l_max
        assert new_scale == new_ql_max
    finally:
        # Restore even when an assertion fails.
        config.IGNORE_MISSING_KEYS = previous_ignore_missing_keys
示例#2
0
def test_parameter_from_stats_state_dict():
    """Check that ``load_state_dict`` carries the weight scale across layers.

    Two binary-weight ``QuantLinear`` layers are built with different
    ``weight_scaling_init`` values. After the first one loads the second
    one's state dict, its weight scale must differ from what it was before
    and must match the second layer's scale.
    """
    # Settings common to both layers; only the initial scale differs.
    shared_kwargs = {
        'bias': False,
        'weight_quant_type': 'binary',
        'weight_scaling_impl_type': 'parameter',
    }
    target = QuantLinear(10, 5, **shared_kwargs, weight_scaling_init=0.1)
    source = QuantLinear(10, 5, **shared_kwargs, weight_scaling_init=0.001)

    scale_before_load = target.quant_weight_scale()
    target.load_state_dict(source.state_dict())
    scale_after_load = target.quant_weight_scale()
    source_scale = source.quant_weight_scale()

    assert scale_before_load != source_scale
    assert scale_before_load != scale_after_load
    assert scale_after_load == source_scale
示例#3
0
 def op_symbolic_kwargs(cls, module: QuantLinear):
     """Collect the keyword arguments describing ``module``'s quantized op.

     Scales are read from ``module``; zero points, the integer weight, the
     output dtype and the output shape come from helpers on ``cls``.

     Args:
         module: The ``QuantLinear`` layer to describe.

     Returns:
         A dict with input, weight and output quantization entries, plus
         ``output_dtype`` and ``out_shape``.
     """
     input_entries = {
         'input_scale': module.quant_input_scale(),
         'input_zero_point': cls.quant_input_zero_point(module),
     }
     weight_entries = {
         # The integer weight is passed transposed in this variant.
         'int_weight': cls.int_weight(module).t(),
         'weight_scale': module.quant_weight_scale(),
         'weight_zero_point': cls.quant_weight_zero_point(module),
     }
     output_entries = {
         'output_scale': module.quant_output_scale(),
         'output_zero_point': cls.quant_output_zero_point(module),
         'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed),
         'out_shape': cls.quant_output_shape(module),
     }

     # Merge in input -> weight -> output order to keep key ordering stable.
     symbolic_kwargs = {}
     for entries in (input_entries, weight_entries, output_entries):
         symbolic_kwargs.update(entries)
     return symbolic_kwargs
示例#4
0
 def op_symbolic_kwargs(cls, module: QuantLinear):
     """Collect the keyword arguments describing ``module``'s quantized op.

     Scales and feature counts are read from ``module``; zero points, the
     integer weight and the output shape come from helpers on ``cls``.

     Args:
         module: The ``QuantLinear`` layer to describe.

     Returns:
         A dict with input, weight and output quantization entries, plus
         ``out_shape``, ``in_features`` and ``out_features``.
     """
     input_entries = {
         'input_scale': module.quant_input_scale(),
         'input_zero_point': cls.quant_input_zero_point(module),
     }
     weight_entries = {
         # The integer weight is passed as returned (no transpose) here.
         'int_weight': cls.int_weight(module),
         'weight_scale': module.quant_weight_scale(),
         'weight_zero_point': cls.quant_weight_zero_point(module),
     }
     output_entries = {
         'output_scale': module.quant_output_scale(),
         'output_zero_point': cls.quant_output_zero_point(module),
         'out_shape': cls.quant_output_shape(module),
     }
     feature_entries = {
         'in_features': module.in_features,
         'out_features': module.out_features,
     }

     # Merge in the original order to keep key ordering stable.
     symbolic_kwargs = {}
     for entries in (input_entries, weight_entries, output_entries,
                     feature_entries):
         symbolic_kwargs.update(entries)
     return symbolic_kwargs
示例#5
0
 def test_module_init_scale_impl_type_override(self):
     """Build a ``QuantLinear`` with ``weight_scaling_impl_type='HE'``.

     The layer is constructed with the overridden scaling implementation
     type and the resulting weight scale is asserted to be truthy.
     """
     quant_layer = QuantLinear(
         out_features=OUTPUT_FEATURES,
         in_features=INPUT_FEATURES,
         bias=True,
         weight_scaling_impl_type='HE',
     )
     weight_scale = quant_layer.quant_weight_scale()
     # Truthiness check on the value returned by quant_weight_scale().
     assert weight_scale