def test_parameter_from_stats_update(): config.IGNORE_MISSING_KEYS = True linear = nn.Linear(10, 5, bias=False) q_linear = QuantLinear(10, 5, bias=False, weight_quant_type='binary', weight_scaling_impl_type='parameter_from_stats') l_max = linear.weight.abs().max() old_scale = q_linear.quant_weight_scale() old_ql_max = q_linear.weight.abs().max() q_linear.load_state_dict(linear.state_dict()) new_scale = q_linear.quant_weight_scale() new_ql_max = q_linear.weight.abs().max() assert old_scale == old_ql_max assert new_scale == l_max assert new_scale == new_ql_max
def test_parameter_from_stats_state_dict(): q_linear1 = QuantLinear(10, 5, bias=False, weight_quant_type='binary', weight_scaling_impl_type='parameter', weight_scaling_init=0.1) q_linear2 = QuantLinear(10, 5, bias=False, weight_quant_type='binary', weight_scaling_impl_type='parameter', weight_scaling_init=0.001) q_linear1_old_scale = q_linear1.quant_weight_scale() q_linear1.load_state_dict(q_linear2.state_dict()) q_linear1_new_scale = q_linear1.quant_weight_scale() q_linear2_scale = q_linear2.quant_weight_scale() assert q_linear1_old_scale != q_linear2_scale assert q_linear1_old_scale != q_linear1_new_scale assert q_linear1_new_scale == q_linear2_scale
def op_symbolic_kwargs(cls, module: QuantLinear): linear_symbolic_kwargs = { 'input_scale': module.quant_input_scale(), 'input_zero_point': cls.quant_input_zero_point(module), 'int_weight': cls.int_weight(module).t(), 'weight_scale': module.quant_weight_scale(), 'weight_zero_point': cls.quant_weight_zero_point(module), 'output_scale': module.quant_output_scale(), 'output_zero_point': cls.quant_output_zero_point(module), 'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed), 'out_shape': cls.quant_output_shape(module)} return linear_symbolic_kwargs
def op_symbolic_kwargs(cls, module: QuantLinear): linear_symbolic_kwargs = { 'input_scale': module.quant_input_scale(), 'input_zero_point': cls.quant_input_zero_point(module), 'int_weight': cls.int_weight(module), 'weight_scale': module.quant_weight_scale(), 'weight_zero_point': cls.quant_weight_zero_point(module), 'output_scale': module.quant_output_scale(), 'output_zero_point': cls.quant_output_zero_point(module), 'out_shape': cls.quant_output_shape(module), 'in_features': module.in_features, 'out_features': module.out_features } return linear_symbolic_kwargs
def test_module_init_scale_impl_type_override(self): mod = QuantLinear( out_features=OUTPUT_FEATURES, in_features=INPUT_FEATURES, bias=True, weight_scaling_impl_type='HE') assert mod.quant_weight_scale()