Example #1
0
    def test_quant_input_hidden(self, verbose):
        """QuantLSTMCell vs. manual input quantization + pytorchLSTMCell."""
        batch, in_features, state_size = 15, 121, 51
        bits = 4

        desc_in = tensor_quant.QuantDescriptor(num_bits=bits)
        desc_w = tensor_quant.QuantDescriptor(num_bits=bits)
        quant_cell = quant_rnn.QuantLSTMCell(
            in_features,
            state_size,
            bias=False,
            quant_desc_input=desc_in,
            quant_desc_weight=desc_w)
        ref_cell = nn.LSTMCell(in_features, state_size, bias=False)

        x = torch.randn(batch, in_features)
        h0 = torch.randn(batch, state_size)
        c0 = torch.randn(batch, state_size)

        # Quantized cell does its own input/weight quantization internally.
        quant_hout, quant_cout = quant_cell(x, hx=(h0, c0))

        # Quantize input and hidden by hand, copy quantized weights into the
        # reference cell, then run the plain pytorch cell on them.
        qx, qh = utils.quantize_by_range_fused((x, h0), bits)
        utils.copy_state_and_quantize_fused(ref_cell, quant_cell, bits)
        ref_hout, ref_cout = ref_cell(qx, hx=(qh, c0))

        utils.compare(quant_hout, ref_hout)
        utils.compare(quant_cout, ref_cout)
    def test_fake_quant_against_unquantized(self):
        """
        Quantized Linear should introduce bounded error compare to Linear.

        Bug fix: the original moved the test input to CUDA unconditionally
        (crashing on CPU-only machines, despite the guarded seed calls) while
        both layers were left on the CPU, so the forward passes would hit a
        device mismatch on GPU machines. Pick one device up front and keep
        the input and both layers on it.
        """
        size_in = 255
        size_out = 257
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        test_input = torch.randn(32, size_in, device=device)

        torch.manual_seed(1234)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(1234)
        quant_linear_layer = quant_linear.QuantLinear(
            size_in,
            size_out,
            bias=True,
            quant_desc_input=tensor_quant.QuantDescriptor(num_bits=16),
            quant_desc_weight=tensor_quant.QuantDescriptor(num_bits=16,
                                                           axis=0)).to(device)

        # Reset seed. Make sure weight and bias are the same
        torch.manual_seed(1234)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(1234)
        linear_layer = nn.Linear(size_in, size_out, bias=True).to(device)

        quant_out_features = quant_linear_layer(test_input)
        out_features = linear_layer(test_input)

        # The difference between Linear and QuantLinear should be bounded in a range
        # Small values which become 0 after quantization lead to large relative errors. rtol and atol could be
        # much smaller without those values
        np.testing.assert_allclose(quant_out_features.detach().cpu().numpy(),
                                   out_features.detach().cpu().numpy(),
                                   rtol=0.01,
                                   atol=1e-4)
Example #3
0
    def test_quant_different_prec(self, verbose):
        """QuantLSTM vs. manual input quantization + pytorchLSTM."""
        batch, in_features, state_size, seq_len = 22, 23, 24, 1
        w_bits, in_bits = 4, 8

        # Weights and activations use different precisions.
        desc_in = tensor_quant.QuantDescriptor(num_bits=in_bits)
        desc_w = tensor_quant.QuantDescriptor(num_bits=w_bits)
        quant_lstm = quant_rnn.QuantLSTM(
            in_features, state_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False,
            quant_desc_input=desc_in, quant_desc_weight=desc_w)
        ref_lstm = nn.LSTM(
            in_features, state_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False)

        x = torch.randn(seq_len, batch, in_features)
        h0 = torch.randn(seq_len, batch, state_size)
        c0 = torch.randn(seq_len, batch, state_size)

        # Quantize input/hidden by hand for the reference path and copy the
        # quantized weights into the reference LSTM.
        qx, qh = utils.quantize_by_range_fused((x, h0), in_bits)
        utils.copy_state_and_quantize_fused(ref_lstm, quant_lstm, w_bits)

        quant_out, (quant_hout, quant_cout) = quant_lstm(x, hx=(h0, c0))
        ref_out, (ref_hout, ref_cout) = ref_lstm(qx, hx=(qh, c0))

        utils.compare(quant_out, ref_out)
        utils.compare(quant_hout, ref_hout)
        utils.compare(quant_cout, ref_cout)
    def test_fake_quant_per_channel_other_precs(self):
        """Test some precisions other than 8bit."""
        size_in, size_out = 255, 257

        desc_in = tensor_quant.QuantDescriptor(num_bits=4)
        desc_w = tensor_quant.QuantDescriptor(num_bits=3)
        layer = quant_linear.QuantLinear(
            size_in,
            size_out,
            bias=False,
            quant_desc_input=desc_in,
            quant_desc_weight=desc_w)

        # Stand-alone quantizers built from the same descriptors must
        # reproduce the layer's internal quantization exactly.
        w_quantizer = TensorQuantizer(desc_w)
        in_quantizer = TensorQuantizer(desc_in)

        x = torch.randn(32, size_in)
        w = layer.weight.clone()

        manual_out = F.linear(in_quantizer(x), w_quantizer(w))
        layer_out = layer(x)
        np.testing.assert_array_equal(manual_out.detach().cpu().numpy(),
                                      layer_out.detach().cpu().numpy())
Example #5
0
    def test_against_unquantized(self, verbose):
        """Quantization should introduce bounded error utils.compare to pytorch implementation."""
        batch, in_features, state_size = 9, 13, 7

        quant_cell = quant_rnn.QuantLSTMCell(
            in_features, state_size, bias=False,
            quant_desc_input=tensor_quant.QuantDescriptor(num_bits=16),
            quant_desc_weight=tensor_quant.QuantDescriptor(num_bits=16, axis=(1,)))
        ref_cell = nn.LSTMCell(in_features, state_size, bias=False)

        # copy weights from one rnn to the other
        ref_cell.load_state_dict(quant_cell.state_dict())

        x = torch.randn(batch, in_features)
        h0 = torch.randn(batch, state_size)
        c0 = torch.randn(batch, state_size)

        quant_hout, quant_cout = quant_cell(x, hx=(h0, c0))
        ref_hout, ref_cout = ref_cell(x, hx=(h0, c0))

        # The difference between reference and quantized should be bounded in a range
        # Small values which become 0 after quantization lead to large relative errors. rtol and atol could be
        # much smaller without those values
        utils.compare(quant_hout, ref_hout, rtol=1e-4, atol=1e-4)
        utils.compare(quant_cout, ref_cout, rtol=1e-4, atol=1e-4)

        # check that quantization introduces some error
        utils.assert_min_mse(quant_hout, ref_hout, tol=1e-20)
        utils.assert_min_mse(quant_cout, ref_cout, tol=1e-20)
    def test_amax(self):
        """amax normalization: None stays None, numbers become float32 ndarrays."""
        # No amax supplied -> property is unset.
        desc = tensor_quant.QuantDescriptor()
        assert desc.amax is None

        # A scalar is stored as a float32 ndarray.
        desc = tensor_quant.QuantDescriptor(amax=1.2)
        assert isinstance(desc.amax, np.ndarray)
        np.testing.assert_array_equal(desc.amax, np.float32(1.2))

        # A list likewise.
        desc = tensor_quant.QuantDescriptor(amax=[1.3, 1.4])
        assert isinstance(desc.amax, np.ndarray)
        np.testing.assert_array_equal(desc.amax, np.float32([1.3, 1.4]))

        # Anything else is rejected with a TypeError.
        with pytest.raises(TypeError, match="must be float, list or ndarray"):
            tensor_quant.QuantDescriptor(amax='oops')
Example #7
0
 def test_raise(self):
     """QuantAdaptiveAvgPool3d must reject a non-fake-quant input descriptor."""
     bad_desc = tensor_quant.QuantDescriptor(fake_quant=False)
     with pytest.raises(ValueError) as excinfo:
         quant_pooling.QuantAdaptiveAvgPool3d(output_size=3,
                                              quant_desc_input=bad_desc)
     assert "Only fake quantization is supported" in str(excinfo.value)
    def test_from_to_yaml(self, verbose):
        """QuantDescriptor must round-trip through to_yaml()/from_yaml().

        Bug fix: the body reads ``verbose`` but the signature did not accept
        it, so the print branch raised NameError; accept it as a fixture the
        way the sibling tests (e.g. test_quant_input_hidden) do.
        """
        # Round-trip a descriptor with every non-default field populated.
        quant_desc_1 = tensor_quant.QuantDescriptor(num_bits=2,
                                                    name='a',
                                                    fake_quant=True,
                                                    axis=(1, 2),
                                                    amax=3.1415926536)
        quant_desc_2 = tensor_quant.QuantDescriptor.from_yaml(
            quant_desc_1.to_yaml())
        if verbose:
            print(quant_desc_1.to_yaml())
        assert quant_desc_1 == quant_desc_2

        # And a minimal one.
        quant_desc_1 = tensor_quant.QuantDescriptor(num_bits=2, amax=0.1)
        quant_desc_2 = tensor_quant.QuantDescriptor.from_yaml(
            quant_desc_1.to_yaml())
        assert quant_desc_1 == quant_desc_2
    def test_set_default_quant_desc(self):
        """Changing the class-level default descriptors affects new layers.

        Idiom fixes: ``is None`` instead of ``== None`` for the None check,
        and ``0`` instead of the redundant ``(0)``.
        """
        quant_linear_layer = quant_linear.QuantLinear(32, 257)
        assert quant_linear_layer.input_quantizer.axis is None
        assert quant_linear_layer.weight_quantizer.axis == 0

        # set default to a different one
        quant_desc_input = tensor_quant.QuantDescriptor(num_bits=11)
        quant_desc_weight = tensor_quant.QuantDescriptor(num_bits=13, axis=1)
        quant_linear.Linear.set_default_quant_desc_input(quant_desc_input)
        quant_linear.Linear.set_default_quant_desc_weight(quant_desc_weight)

        # Create one with default descriptor
        quant_linear_layer = quant_linear.QuantLinear(32, 257)
        # Check quant_desc in quantizer created with default descriptor
        assert quant_linear_layer.input_quantizer.num_bits == quant_desc_input.num_bits
        assert quant_linear_layer.weight_quantizer.axis == quant_desc_weight.axis
    def test_from_to_dict(self, verbose):
        """QuantDescriptor must round-trip through dict() / **kwargs.

        Bug fix: the body reads ``verbose`` but the signature did not accept
        it, so the print branch raised NameError; accept it as a fixture the
        way the sibling tests do.
        """
        quant_desc_1 = tensor_quant.QuantDescriptor(num_bits=2,
                                                    name='a',
                                                    fake_quant=True,
                                                    axis=(1, 2),
                                                    amax=3.1415926536)
        quant_desc_2 = tensor_quant.QuantDescriptor(**quant_desc_1.dict())
        if verbose:
            print(quant_desc_1.dict())
        assert quant_desc_1 == quant_desc_2

        quant_desc_1 = tensor_quant.QuantDescriptor(num_bits=2,
                                                    amax=0.1,
                                                    unsigned=True)
        quant_desc_2 = tensor_quant.QuantDescriptor(**quant_desc_1.dict())
        assert quant_desc_1 == quant_desc_2
Example #11
0
    def test_per_channel_scale(self, verbose):
        """Quantizer performs per channel scaling"""
        x_np = np.random.rand(15, 15, 64, 128).astype('float32')
        x_torch = torch.Tensor(x_np).cuda()

        # Pytorch filter layout seems to be KCRS, reduce max to shape [K, 1, 1, 1] to test per channel scale
        # Shrink max a little, so that clip behavior is tested
        amax_x_np = 0.7 * np.max(np.abs(x_np), axis=(1, 2, 3), keepdims=True)
        expected = test_utils.quant_np(x_np, amax_x_np)

        desc = tensor_quant.QuantDescriptor(num_bits=8, axis=(0),
                                            fake_quant=False, scale_amax=0.7)
        quantizer = tensor_quantizer.TensorQuantizer(desc)
        quantizer.cuda()
        actual = quantizer(x_torch)

        # np.testing.assert_array_equal(quant_x_torch.cpu().numpy(), quant_x_np)
        # Pytorch numerics is not the same as numpy, it will be off by 1
        abs_err = np.abs(actual.cpu().numpy() - expected)
        np.testing.assert_array_less(abs_err, 2)
        if verbose:
            bad = np.where(abs_err >= 1)
            print("Mismatches:")
            print(" Original: ", x_np[bad])
            print(" numpy: ", expected[bad])
            print(" TensorQuantizer: ", actual.cpu().numpy()[bad])
Example #12
0
    def test_basic_forward(self, verbose):
        """Do a forward pass on the layer module and see if anything catches fire."""
        batch, in_features, state_size, seq_len = 5, 13, 31, 1

        lstm = quant_rnn.QuantLSTM(
            in_features, state_size,
            num_layers=1, bias=False, batch_first=False, dropout=0,
            bidirectional=False,
            quant_desc_input=tensor_quant.QuantDescriptor(num_bits=8),
            quant_desc_weight=tensor_quant.QuantDescriptor(num_bits=8, axis=(1,)))

        x = torch.randn(seq_len, batch, in_features)
        h0 = torch.randn(seq_len, batch, state_size)
        c0 = torch.randn(seq_len, batch, state_size)
        # Smoke test only: no assertions, just must not raise.
        lstm(x, hx=(h0, c0))
Example #13
0
 def test_per_tensor_scale(self):
     """Quantizer performs expected quantization"""
     x_np = np.random.rand(1023)
     expected = test_utils.quant_np(x_np, np.max(np.abs(x_np)))
     desc = tensor_quant.QuantDescriptor(num_bits=8, fake_quant=False)
     quantizer = tensor_quantizer.TensorQuantizer(desc)
     actual = quantizer(torch.Tensor(x_np))
     np.testing.assert_array_equal(actual.cpu().numpy(), expected)
Example #14
0
 def test_raise(self):
     """QuantMaxPool2d must reject a non-fake-quant input descriptor."""
     bad_desc = tensor_quant.QuantDescriptor(fake_quant=False)
     with pytest.raises(ValueError) as excinfo:
         quant_pooling.QuantMaxPool2d(kernel_size=3,
                                      stride=1,
                                      quant_desc_input=bad_desc)
     assert "Only fake quantization is supported" in str(excinfo.value)
 def test_raise(self):
     """QuantLinear must reject a non-fake-quant weight descriptor."""
     # NOTE(review): this method has the same name as the test_raise defined
     # immediately above it; if both live in the same class, the later
     # definition shadows the earlier one and only this one is collected —
     # confirm and rename one of them.
     with pytest.raises(ValueError) as excinfo:
         quant_linear_object = quant_linear.QuantLinear(
             7,
             9,
             bias=False,
             quant_desc_weight=tensor_quant.QuantDescriptor(
                 fake_quant=False))
     assert "Only fake quantization is supported" in str(excinfo.value)
Example #16
0
    def test_basic_forward(self, verbose):
        """Do a forward pass on the cell module and see if anything catches fire."""
        batch, in_features, state_size = 7, 11, 9

        cell_module = quant_rnn.QuantLSTMCell(
            in_features, state_size, bias=False,
            quant_desc_input=tensor_quant.QuantDescriptor(num_bits=8),
            quant_desc_weight=tensor_quant.QuantDescriptor(num_bits=8, axis=(1,)))
        # Run with both quantizers disabled; the forward must still work.
        cell_module._input_quantizer.disable()
        cell_module._weight_quantizer.disable()

        x = torch.randn(batch, in_features)
        h0 = torch.randn(batch, state_size)
        c0 = torch.randn(batch, state_size)

        # Smoke test only: no assertions, just must not raise.
        cell_module(x, hx=(h0, c0))
Example #17
0
    def test_state_loading(self):
        """Test quant_desc loading via state_dict"""
        amax_values = [3.142, 2.718]
        desc = tensor_quant.QuantDescriptor(amax=amax_values)
        quantizer = tensor_quantizer.TensorQuantizer(desc)

        # Round-trip the quantizer's own state through load_state_dict; the
        # amax buffer must survive unchanged.
        quantizer.load_state_dict(quantizer.state_dict())
        np.testing.assert_array_equal(quantizer.amax.detach().cpu().numpy(),
                                      desc.amax)
Example #18
0
    def test_init_calib(self):
        """Calibration over two axes yields one amax per (axis0, axis1) pair."""
        desc = tensor_quant.QuantDescriptor(axis=(0, 1))
        quantizer = tensor_quantizer.TensorQuantizer(desc, if_calib=True).cuda()

        sample = torch.rand(127, 63, 7, 7).cuda()
        quantizer(sample)
        quantizer.load_calib_amax()

        # One amax value per element of the two calibration axes.
        assert quantizer.amax.numel() == 127 * 63
Example #19
0
    def test_properties(self):
        """amax/step_size properties: setter accepts float and ndarray."""
        # Float assignment overrides the descriptor's amax.
        q1 = tensor_quantizer.TensorQuantizer(
            tensor_quant.QuantDescriptor(amax=3.14))
        q1.amax = 0.577
        assert q1.amax.detach().cpu().numpy() == np.float32(0.577)
        np.testing.assert_array_equal(q1.amax.detach().cpu().numpy(), q1.amax)
        assert q1.step_size == 0.577 / 127.

        # ndarray assignment is kept element-wise.
        q2 = tensor_quantizer.TensorQuantizer(tensor_quant.QuantDescriptor())
        amax_np = np.array([3.142, 2.718], dtype=np.float32)
        q2.amax = amax_np
        np.testing.assert_array_equal(q2.amax.detach().cpu().numpy(), amax_np)

        # Without any amax, the property is None.
        q3 = tensor_quantizer.TensorQuantizer(tensor_quant.QuantDescriptor())
        assert q3.amax is None
Example #20
0
 def test_simple_run_no_fake(self):
     """Quantizer fake_quant=False calls tensor_quant and sets the scale property"""
     x = torch.randn(3, 7).cuda()
     amax = torch.max(torch.abs(x))
     # Reference result straight from the functional API.
     ref_quant_x, ref_scale = tensor_quant.tensor_quant(x, amax)
     quantizer = tensor_quantizer.TensorQuantizer(
         tensor_quant.QuantDescriptor(num_bits=8, fake_quant=False))
     module_quant_x = quantizer(x)
     np.testing.assert_array_equal(ref_quant_x.cpu().numpy(),
                                   module_quant_x.cpu().numpy())
     np.testing.assert_array_equal(ref_scale.cpu().numpy(),
                                   quantizer.scale.cpu().numpy())
Example #21
0
 def test_clip_mode(self):
     """Test the clip stage only"""
     x_np = np.random.rand(1023).astype(np.float32)
     amax = 0.5
     expected = np.clip(x_np, -amax, amax)
     # if_quant=False / if_clip=True exercises only the clip submodule.
     quantizer = tensor_quantizer.TensorQuantizer(
         tensor_quant.QuantDescriptor(amax=amax, learn_amax=True),
         if_quant=False, if_clip=True)
     assert hasattr(quantizer, 'clip')
     clipped = quantizer(torch.Tensor(x_np))
     np.testing.assert_array_equal(clipped.cpu().detach().numpy(), expected)
Example #22
0
 def test_learn_amax(self):
     """Test the clip implied by learn_amax"""
     x_np = np.random.rand(1023).astype(np.float32)
     amax = 0.5
     expected = test_utils.quant_np(x_np, 0.5, fake=True)
     quantizer = tensor_quantizer.TensorQuantizer(
         tensor_quant.QuantDescriptor(num_bits=8, amax=amax, learn_amax=True))
     # learn_amax=True must attach a clip submodule.
     assert hasattr(quantizer, 'clip')
     actual = quantizer(torch.Tensor(x_np))
     np.testing.assert_array_equal(actual.cpu().detach().numpy(), expected)
Example #23
0
    def test_max_calib(self):
        """Max calibration tracks the running per-axis amax over inputs.

        Bug fix: the original constructed quantizer1 twice with identical,
        copy-pasted code; the first construction was dead. One construction
        is kept.
        """
        axis = 0
        reduce_axis = (1, 2, 3)
        quant_desc1 = tensor_quant.QuantDescriptor(axis=axis)
        quantizer1 = tensor_quantizer.TensorQuantizer(quant_desc1).cuda()
        quantizer1.enable_calib()

        # Before any data is seen, loading amax must fail loudly.
        with pytest.raises(RuntimeError, match="Calibrator returned None"):
            quantizer1.load_calib_amax()

        x_1 = torch.rand(127, 63, 7, 7).cuda()
        x_2 = torch.rand(127, 63, 7, 7).cuda()
        quantizer1(x_1)
        quantizer1(x_2)
        quantizer1.disable_calib()

        # Expected amax is the element-wise max across both batches, reduced
        # over all axes except `axis`.
        global_amax = torch.max(
            quant_utils.reduce_amax(x_1, axis=reduce_axis, keepdims=True),
            quant_utils.reduce_amax(x_2, axis=reduce_axis, keepdims=True))
        test_utils.compare(quantizer1._calibrator.compute_amax(), global_amax, atol=0, rtol=0, ctol=0)

        quantizer1.load_calib_amax()
        test_utils.compare(quantizer1.amax, global_amax, atol=0, rtol=0, ctol=0)

        # With learn_amax, init_learn_amax seeds the clip bounds from the
        # calibrated (per-tensor) amax.
        quant_desc2 = tensor_quant.QuantDescriptor(learn_amax=True)
        quantizer2 = tensor_quantizer.TensorQuantizer(quant_desc2).cuda()
        quantizer2.enable_calib()
        quantizer2(x_1)
        quantizer2(x_2)

        quantizer2.load_calib_amax()
        quantizer2.init_learn_amax()
        test_utils.compare(quantizer2.clip.clip_value_min, -torch.max(global_amax), atol=0, rtol=0, ctol=0)
        test_utils.compare(quantizer2.clip.clip_value_max, torch.max(global_amax), atol=0, rtol=0, ctol=0)
Example #24
0
    def test_scale_amax(self):
        """scale_amax multiplies the effective amax before fake quantization."""
        x_np = np.random.rand(1023).astype(np.float32)
        x_torch = torch.Tensor(x_np)
        amax = 0.5
        scale_amax = 0.9
        # Reference: quantize against the pre-scaled amax.
        quant_x_np = test_utils.quant_np(x_np, amax * scale_amax, fake=True)
        quantizer = tensor_quantizer.TensorQuantizer(
            tensor_quant.QuantDescriptor(num_bits=8, amax=amax, scale_amax=scale_amax))
        module_quant_x = quantizer(x_torch)
        np.testing.assert_array_equal(module_quant_x.cpu().detach().numpy(), quant_x_np)

        # Test twice. There was a bug in scale amax logic that modified the amax every time
        module_quant_x = quantizer(x_torch)
        np.testing.assert_array_equal(module_quant_x.cpu().detach().numpy(), quant_x_np)
Example #25
0
    def test_input_variable_bits(self):
        # Repeat checking the output for variable number of bits to QuantDescriptor
        for bits in [2, 4, 6]:
            quant_pooling.QuantMaxPool2d.set_default_quant_desc_input(
                tensor_quant.QuantDescriptor(num_bits=bits))
            pool = quant_pooling.QuantMaxPool2d(kernel_size=3, stride=1)

            x = torch.randn(1, 5, 5, 5, dtype=torch.double)

            # Manually fake-quantize the input, then apply a plain max-pool;
            # the quantized pooling layer must match exactly.
            qx = tensor_quant.fake_tensor_quant(x, torch.max(torch.abs(x)),
                                                bits)
            expected = F.max_pool2d(qx, 3, 1, 0, 1, False, False)
            actual = pool(x)
            np.testing.assert_array_equal(expected.detach().cpu().numpy(),
                                          actual.detach().cpu().numpy())
    def test_weight_fake_per_tensor(self):
        """Per-tensor weight fake-quant (axis=None) matches the manual path."""
        with torch.cuda.device(0):
            size = 256
            layer = quant_linear.QuantLinear(
                size,
                size,
                bias=False,
                quant_desc_weight=tensor_quant.QuantDescriptor(axis=None))
            # Disable input quantization so only the weight path is compared.
            layer.input_quantizer.disable()
            x = torch.randn(size, size)

            w = layer.weight.clone()
            qw = tensor_quant.fake_tensor_quant(w, torch.max(torch.abs(w)))

            expected = F.linear(x, qw)
            actual = layer(x)
            np.testing.assert_array_equal(expected.detach().cpu().numpy(),
                                          actual.detach().cpu().numpy())
Example #27
0
    def test_entropy_and_percentile_calib(self):
        """Don't really have a good way to test it."""
        desc = tensor_quant.QuantDescriptor(calib_method='histogram')
        quantizer = tensor_quantizer.TensorQuantizer(
            desc, if_calib=True, if_quant=False).cuda()

        x_1 = torch.rand(3, 63, 7, 7).cuda()
        x_2 = torch.rand(3, 63, 7, 7).cuda()

        # Entropy method: loaded amax must equal the calibrator's own answer.
        quantizer(x_1)
        quantizer(x_2)
        quantizer.load_calib_amax("entropy")
        test_utils.compare(quantizer._calibrator.compute_amax("entropy"),
                           quantizer.amax, atol=0, rtol=0, ctol=0)
        quantizer._calibrator.reset()

        # Percentile method, same consistency check.
        quantizer(x_1)
        quantizer(x_2)
        quantizer.load_calib_amax("percentile", percentile=99.99)
        test_utils.compare(
            quantizer._calibrator.compute_amax("percentile", percentile=99.99),
            quantizer.amax, atol=0, rtol=0, ctol=0)
    def test_scaled_mode(self, verbose):
        """QuantDescriptor constructor arguments and their validation.

        Bug fix: the body reads ``verbose`` but the signature did not accept
        it, so the print branch raised NameError; accept it as a fixture the
        way the sibling tests do.
        """
        # NOTE(review): randint(0, 16) can yield num_bits=0 — confirm the
        # descriptor is meant to accept 0 bits.
        num_bits = np.random.randint(0, 16)

        test_quant_desc = tensor_quant.QuantDescriptor(num_bits=num_bits)
        assert test_quant_desc.num_bits == num_bits
        assert test_quant_desc.axis is None
        assert test_quant_desc.amax is None
        assert not test_quant_desc.learn_amax

        axis = (0, 1, 3)
        test_quant_desc = tensor_quant.QuantDescriptor(axis=axis)
        assert test_quant_desc.num_bits == 8  # default value
        assert test_quant_desc.axis == axis
        assert test_quant_desc.amax is None

        amax = 0.7
        test_quant_desc = tensor_quant.QuantDescriptor(amax=amax,
                                                       unsigned=True)
        assert test_quant_desc.axis is None
        assert test_quant_desc.amax == np.float32(amax)
        assert test_quant_desc.unsigned

        amax = 0.7
        test_quant_desc = tensor_quant.QuantDescriptor(amax=amax,
                                                       learn_amax=True)
        assert test_quant_desc.amax == np.float32(amax)
        assert test_quant_desc.learn_amax

        # Test the print string once if verbose is set.
        if verbose:
            print(test_quant_desc)

        with pytest.raises(TypeError, match="must be float, list or ndarray"):
            tensor_quant.QuantDescriptor(amax='oops')

        with pytest.raises(TypeError,
                           match="amax must be float, list or ndarray"):
            tensor_quant.QuantDescriptor(amax='oops', learn_amax=True)

        with pytest.raises(TypeError,
                           match="axis is ignored and must be None"):
            tensor_quant.QuantDescriptor(axis=(1, 2),
                                         amax=0.7,
                                         learn_amax=True)
    def test_fake_quant_per_tensor(self):
        """quantize everything, activations will scaled per tensor in ALL cases"""
        size_in, size_out = 255, 257
        layer = quant_linear.QuantLinear(
            size_in,
            size_out,
            bias=False,
            quant_desc_weight=tensor_quant.QuantDescriptor())
        x = torch.randn(32, size_in)

        # Manual path: fake-quantize input and weight per tensor, then a
        # plain linear; must match the quantized layer exactly.
        w = layer.weight.clone()
        qx = tensor_quant.fake_tensor_quant(x, torch.max(torch.abs(x)))
        qw = tensor_quant.fake_tensor_quant(w, torch.max(torch.abs(w)))

        expected = F.linear(qx, qw)
        actual = layer(x)
        np.testing.assert_array_equal(expected.detach().cpu().numpy(),
                                      actual.detach().cpu().numpy())