def test_scaling_parameter_from_stats(self):
    """Check when the PARAMETER_FROM_STATS scaling value starts getting a gradient.

    For the first ``collect_stats_steps`` forward/backward passes in training
    mode the scaling value's ``.grad`` must stay ``None``; after one further
    pass it must be populated.
    """
    input_shape = [8, 3, 64, 64]
    warmup_steps = 100
    quant_relu = QuantReLU(
        bit_width=BIT_WIDTH,
        quant_type=QuantType.INT,
        scaling_impl_type=ScalingImplType.PARAMETER_FROM_STATS,
        scaling_stats_permute_dims=None,
        scaling_stats_op=StatsOp.PERCENTILE,
        collect_stats_steps=warmup_steps,
        scaling_min_val=None,
        percentile_q=99.0,
    )
    quant_relu.train()

    # Grab the scaling value once, up front; the same object is inspected
    # both during and after the statistics-collection phase.
    quantizer = quant_relu.act_quant.fused_activation_quant_proxy.tensor_quant
    scale_param = quantizer.scaling_impl.value

    for _ in range(warmup_steps):
        result = quant_relu(torch.randn(input_shape))
        # Per the original author's note, something has to require grad for
        # backward() to run here, so the output itself is flagged.
        result.requires_grad_(True)
        result.sum().backward()
        assert scale_param.grad is None

    # One more pass beyond the collection steps: now a gradient is expected.
    result = quant_relu(torch.randn(input_shape))
    result.sum().backward()
    assert scale_param.grad is not None
def test_scaling_stats_to_parameter(self):
    """Check that STATS-scaling state loads into a PARAMETER-scaling module.

    A ``QuantReLU`` using ``ScalingImplType.STATS`` is run on random inputs
    in training mode, its state dict is loaded into a ``QuantReLU`` using
    ``ScalingImplType.PARAMETER``, and both modules must then report the
    same activation scale in eval mode.
    """
    source_act = QuantReLU(
        bit_width=BIT_WIDTH,
        max_val=MAX_VAL,
        quant_type=QuantType.INT,
        scaling_impl_type=ScalingImplType.STATS,
        scaling_stats_permute_dims=None,
        scaling_stats_op=StatsOp.MAX,
    )
    source_act.train()
    for _ in range(RANDOM_ITERS):
        source_act(torch.randn([8, 3, 64, 64]))

    saved_state = source_act.state_dict()

    target_act = QuantReLU(
        bit_width=BIT_WIDTH,
        max_val=MAX_VAL,
        quant_type=QuantType.INT,
        scaling_impl_type=ScalingImplType.PARAMETER,
    )
    target_act.load_state_dict(saved_state)

    # Compare scales in eval mode for both modules.
    source_act.eval()
    target_act.eval()
    source_scale = source_act.quant_act_scale()
    target_scale = target_act.quant_act_scale()
    assert torch.allclose(source_scale, target_scale)
def test_scaling_parameter_grad(self):
    """Check that every backward pass gives the PARAMETER scaling value a gradient.

    A ``QuantReLU`` using ``ScalingImplType.PARAMETER`` is run in training
    mode for ``RANDOM_ITERS`` iterations on random inputs. In each iteration
    the scaling value's ``.grad`` is cleared, a forward/backward pass is run,
    and ``.grad`` must be populated again.
    """
    stats_act = QuantReLU(
        bit_width=BIT_WIDTH,
        max_val=MAX_VAL,
        quant_type=QuantType.INT,
        scaling_impl_type=ScalingImplType.PARAMETER,
    )
    stats_act.train()
    for _ in range(RANDOM_ITERS):
        inp = torch.randn([8, 3, 64, 64])
        # NOTE(review): the output of this first forward pass is discarded.
        # It is kept from the original in case the module updates internal
        # state on a training-mode forward -- confirm whether it is needed.
        stats_act(inp)

        tensor_quant = stats_act.act_quant.fused_activation_quant_proxy.tensor_quant
        scaling_value = tensor_quant.scaling_impl.value
        # backward() accumulates into .grad, so without this reset the
        # assertion below would pass on every iteration after the first
        # even if the current pass produced no gradient at all.
        scaling_value.grad = None

        out = stats_act(inp)
        out.sum().backward()
        assert scaling_value.grad is not None