def test_int(self):
    # Quantized input as integer values together with its dequantized
    # floating-point counterpart, sharing the same scale.
    shape = (BATCH, IN_CHANNEL, HEIGHT)
    input_quant_int, input_quant = generate_quant_input(
        shape, BIT, SCALE, True, True)
    mod = QuantConv1d(
        in_channels=IN_CHANNEL,
        out_channels=OUT_CHANNEL,
        kernel_size=KERNEL,
        stride=STRIDE,
        weight_quant_type=QuantType.INT,
        weight_bit_width=BIT,
        bias=False)
    # Reference result computed in the dequantized (float) domain.
    results_float_quantized = mod(input_quant)
    # Same convolution computed directly on the integer operands.
    weight_int = mod.int_weight()
    bias = mod.bias
    results_int_quantized = mod.conv1d_zeros_pad(
        input_quant_int, weight_int.float(), bias)
    # Rescaling the float result by the combined input and weight scale
    # should recover the integer accumulator values.
    total_scale = SCALE * mod.quant_weight_scale()
    result_rescaled = torch.round(results_float_quantized / total_scale)
    assert torch.allclose(results_int_quantized, result_rescaled, atol=ATOL, rtol=RTOL)

def quant_weight_scale(module: QuantConv1d):
    quant_weight_scale = module.quant_weight_scale().type(
        torch.FloatTensor).detach()
    if len(quant_weight_scale.shape) == 3:
        # Per-output-channel scale: reshape so it broadcasts across the
        # channel dimension of a (batch, channels, length) output.
        quant_weight_scale = quant_weight_scale.view(1, -1, 1)
    return quant_weight_scale