コード例 #1
0
    def test_no_quant_input_hidden(self, verbose):
        """Check QuantLSTMCell with quantizers disabled against nn.LSTMCell.

        Both cells are bias-free, share one state dict, and receive the same
        random input, hidden state and cell state. Their hidden and cell
        outputs are then checked against each other with ``utils.compare``.

        Args:
            verbose: Not read by this test body (presumably a fixture supplied
                by the test runner -- confirm against the test setup).
        """
        n_samples, n_features, n_hidden = 17, 13, 7

        quant_cell = quant_rnn.QuantLSTMCell(n_features, n_hidden, bias=False)
        # Disable the input quantizer first, then the weight quantizer.
        for quantizer_name in ("_input_quantizer", "_weight_quantizer"):
            getattr(quant_cell, quantizer_name).disable()
        ref_cell = nn.LSTMCell(n_features, n_hidden, bias=False)

        # Give the reference cell the same parameters as the quantized cell.
        ref_cell.load_state_dict(quant_cell.state_dict())

        x = torch.randn(n_samples, n_features)
        h0 = torch.randn(n_samples, n_hidden)
        c0 = torch.randn(n_samples, n_hidden)
        state = (h0, c0)

        quant_hout, quant_cout = quant_cell(x, hx=state)
        ref_hout, ref_cout = ref_cell(x, hx=state)

        # Hidden output is checked before cell output.
        for actual, expected in ((quant_hout, ref_hout),
                                 (quant_cout, ref_cout)):
            utils.compare(actual, expected)
コード例 #2
0
    def test_quant_input_hidden(self, verbose):
        """Check QuantLSTMCell against manual quantization plus nn.LSTMCell.

        The quantized cell is run on raw random tensors with 4-bit input and
        weight descriptors. The reference path passes the input and hidden
        state through ``utils.quantize_by_range_fused``, fills the reference
        cell through ``utils.copy_state_and_quantize_fused``, and then runs
        ``nn.LSTMCell`` with the cell state passed through unchanged. Hidden
        and cell outputs of both paths are checked with ``utils.compare``.

        Args:
            verbose: Not read by this test body (presumably a fixture supplied
                by the test runner -- confirm against the test setup).
        """
        n_samples, n_features, n_hidden = 15, 121, 51
        bit_width = 4

        # Two distinct descriptor objects, one per quantizer.
        input_desc = tensor_quant.QuantDescriptor(num_bits=bit_width)
        weight_desc = tensor_quant.QuantDescriptor(num_bits=bit_width)
        quant_cell = quant_rnn.QuantLSTMCell(n_features,
                                             n_hidden,
                                             bias=False,
                                             quant_desc_input=input_desc,
                                             quant_desc_weight=weight_desc)
        ref_cell = nn.LSTMCell(n_features, n_hidden, bias=False)

        x = torch.randn(n_samples, n_features)
        h0 = torch.randn(n_samples, n_hidden)
        c0 = torch.randn(n_samples, n_hidden)

        # Quantized path: raw tensors go straight into the quantized cell.
        quant_hout, quant_cout = quant_cell(x, hx=(h0, c0))

        # Reference path: process input and hidden state together, then
        # populate the reference cell from the quantized cell.
        x_q, h0_q = utils.quantize_by_range_fused((x, h0), bit_width)
        utils.copy_state_and_quantize_fused(ref_cell, quant_cell, bit_width)

        # The cell state is handed to the reference cell as-is.
        ref_hout, ref_cout = ref_cell(x_q, hx=(h0_q, c0))

        # Hidden output is checked before cell output.
        for actual, expected in ((quant_hout, ref_hout),
                                 (quant_cout, ref_cout)):
            utils.compare(actual, expected)
コード例 #3
0
    def test_against_unquantized(self, verbose):
        """Check that 16-bit quantization error is small but present.

        A QuantLSTMCell with 16-bit input and weight descriptors (the weight
        descriptor uses ``axis=(1,)``) and an ``nn.LSTMCell`` loaded with the
        same state dict are run on identical random tensors. The outputs must
        agree within ``rtol=1e-4`` / ``atol=1e-4`` and must also pass
        ``utils.assert_min_mse`` with ``tol=1e-20``.

        Args:
            verbose: Not read by this test body (presumably a fixture supplied
                by the test runner -- confirm against the test setup).
        """
        n_samples, n_features, n_hidden = 9, 13, 7

        input_desc = tensor_quant.QuantDescriptor(num_bits=16)
        weight_desc = tensor_quant.QuantDescriptor(num_bits=16, axis=(1,))
        quant_cell = quant_rnn.QuantLSTMCell(
            n_features,
            n_hidden,
            bias=False,
            quant_desc_input=input_desc,
            quant_desc_weight=weight_desc,
        )
        ref_cell = nn.LSTMCell(n_features, n_hidden, bias=False)

        # Give the reference cell the same parameters as the quantized cell.
        ref_cell.load_state_dict(quant_cell.state_dict())

        x = torch.randn(n_samples, n_features)
        h0 = torch.randn(n_samples, n_hidden)
        c0 = torch.randn(n_samples, n_hidden)
        state = (h0, c0)

        quant_hout, quant_cout = quant_cell(x, hx=state)
        ref_hout, ref_cout = ref_cell(x, hx=state)
        output_pairs = ((quant_hout, ref_hout), (quant_cout, ref_cout))

        # Upper bound on the quantization error. Small values that quantize
        # to 0 produce large relative errors; without them rtol and atol
        # could be much tighter.
        tolerances = {"rtol": 1e-4, "atol": 1e-4}
        for actual, expected in output_pairs:
            utils.compare(actual, expected, **tolerances)

        # Lower bound: quantization is expected to introduce some error.
        for actual, expected in output_pairs:
            utils.assert_min_mse(actual, expected, tol=1e-20)
コード例 #4
0
    def test_basic_forward(self, verbose):
        """Smoke-test a single forward pass through QuantLSTMCell.

        The cell is built with 8-bit input and weight descriptors (the weight
        descriptor uses ``axis=(1,)``), both quantizers are then disabled, and
        the cell is called once on random tensors. The result is discarded;
        the test only fails if the call raises.

        Args:
            verbose: Not read by this test body (presumably a fixture supplied
                by the test runner -- confirm against the test setup).
        """
        n_samples, n_features, n_hidden = 7, 11, 9

        input_desc = tensor_quant.QuantDescriptor(num_bits=8)
        weight_desc = tensor_quant.QuantDescriptor(num_bits=8, axis=(1,))
        quant_cell = quant_rnn.QuantLSTMCell(
            n_features,
            n_hidden,
            bias=False,
            quant_desc_input=input_desc,
            quant_desc_weight=weight_desc,
        )
        # Disable the input quantizer first, then the weight quantizer.
        for quantizer_name in ("_input_quantizer", "_weight_quantizer"):
            getattr(quant_cell, quantizer_name).disable()

        x = torch.randn(n_samples, n_features)
        h0 = torch.randn(n_samples, n_hidden)
        c0 = torch.randn(n_samples, n_hidden)

        # Output is intentionally ignored; only a clean run is required.
        quant_cell(x, hx=(h0, c0))