Ejemplo n.º 1
0
    def test_int(self):
        """Check that QuantConv1d's integer path agrees with its float path.

        The module is run on the quantized float input. The convolution is then
        recomputed from the integer input and integer weights via
        ``conv1d_zeros_pad``. The float output is divided by
        ``SCALE * quant_weight_scale()`` and rounded, and the two results must
        match within ``ATOL`` / ``RTOL``.
        """
        input_shape = (BATCH, IN_CHANNEL, HEIGHT)
        # Per the original variable names: first item is the integer input,
        # second the quantized float input. The meaning of the two trailing
        # True flags is defined by generate_quant_input (not visible here).
        x_int, x_float = generate_quant_input(
            input_shape, BIT, SCALE, True, True)
        conv = QuantConv1d(
            in_channels=IN_CHANNEL,
            out_channels=OUT_CHANNEL,
            kernel_size=KERNEL,
            stride=STRIDE,
            weight_quant_type=QuantType.INT,
            weight_bit_width=BIT,
            bias=False,
        )

        # Reference result: ordinary forward pass of the quantized module.
        float_out = conv(x_float)

        # Integer-domain recomputation of the same convolution.
        int_weights = conv.int_weight()
        bias_term = conv.bias
        int_out = conv.conv1d_zeros_pad(x_int, int_weights.float(), bias_term)

        # Divide out the combined input * weight scale so both live on the
        # integer grid before comparing.
        combined_scale = SCALE * conv.quant_weight_scale()
        float_out_on_int_grid = torch.round(float_out / combined_scale)

        outputs_match = torch.allclose(int_out, float_out_on_int_grid,
                                       atol=ATOL, rtol=RTOL)
        assert outputs_match
Ejemplo n.º 2
0
 def quant_weight_scale(module: QuantConv1d):
     """Return ``module.quant_weight_scale()`` as a detached ``torch.FloatTensor``.

     The scale is converted with ``.type(torch.FloatTensor)`` and then detached
     from the autograd graph. A 3-D result is reshaped to ``(1, -1, 1)``. This is
     presumably so it broadcasts along the channel axis of a conv1d output;
     confirm against callers. Results of any other rank are returned as-is.
     """
     scale = module.quant_weight_scale()
     scale = scale.type(torch.FloatTensor).detach()
     return scale.view(1, -1, 1) if scale.dim() == 3 else scale