# Example 1
    def test_quant_different_prec(self, verbose):
        """Compare QuantLSTM against a pytorch LSTM fed manually quantized inputs.

        Weights are quantized to 4 bits and activations to 8 bits; the two
        paths must produce matching outputs and final (h, c) states.
        """
        seq_len, batch = 1, 22
        input_size, hidden_size = 23, 24
        num_bits_weight, num_bits_input = 4, 8

        desc_input = tensor_quant.QuantDescriptor(num_bits=num_bits_input)
        desc_weight = tensor_quant.QuantDescriptor(num_bits=num_bits_weight)
        qlstm = quant_rnn.QuantLSTM(
            input_size, hidden_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False,
            quant_desc_input=desc_input, quant_desc_weight=desc_weight)
        ref_lstm = nn.LSTM(
            input_size, hidden_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False)

        inputs = torch.randn(seq_len, batch, input_size)
        hidden = torch.randn(seq_len, batch, hidden_size)
        cell = torch.randn(seq_len, batch, hidden_size)

        # Quantize the activations by hand for the reference path; the cell
        # state is deliberately left unquantized on both sides.
        quant_inputs, quant_hidden = utils.quantize_by_range_fused(
            (inputs, hidden), num_bits_input)

        utils.copy_state_and_quantize_fused(ref_lstm, qlstm, num_bits_weight)

        quant_out, (quant_hout, quant_cout) = qlstm(inputs, hx=(hidden, cell))
        ref_out, (ref_hout, ref_cout) = ref_lstm(quant_inputs, hx=(quant_hidden, cell))

        utils.compare(quant_out, ref_out)
        utils.compare(quant_hout, ref_hout)
        utils.compare(quant_cout, ref_cout)
# Example 2
    def test_no_quant_input_hidden(self, verbose):
        """QuantLSTM with all quantizers disabled must match pytorch LSTM exactly."""
        seq_len, batch = 1, 13
        input_size, hidden_size = 19, 20

        qlstm = quant_rnn.QuantLSTM(
            input_size, hidden_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False)
        # Single layer, unidirectional -> exactly one quantizer of each kind.
        qlstm._input_quantizers[0].disable()
        qlstm._weight_quantizers[0].disable()

        ref_lstm = nn.LSTM(
            input_size, hidden_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False)
        # Share identical weights between the two modules.
        ref_lstm.load_state_dict(qlstm.state_dict())

        inputs = torch.randn(seq_len, batch, input_size)
        hidden = torch.randn(seq_len, batch, hidden_size)
        cell = torch.randn(seq_len, batch, hidden_size)

        quant_out, (quant_hout, quant_cout) = qlstm(inputs, hx=(hidden, cell))
        ref_out, (ref_hout, ref_cout) = ref_lstm(inputs, hx=(hidden, cell))

        utils.compare(quant_out, ref_out)
        utils.compare(quant_hout, ref_hout)
        utils.compare(quant_cout, ref_cout)
# Example 3
        def testcase(input_size, hidden_size, seq_len, batch, num_layers, bias, batch_first, dropout, bidirectional):
            """Build a QuantLSTM with quantization disabled and check it matches nn.LSTM."""
            qlstm = quant_rnn.QuantLSTM(
                input_size, hidden_size, num_layers=num_layers, bias=bias,
                batch_first=batch_first, dropout=dropout,
                bidirectional=bidirectional)

            # One (input, weight) quantizer pair per direction per layer;
            # turn them all off so the module should behave like plain LSTM.
            num_quantizers = num_layers * 2 if bidirectional else num_layers
            for idx in range(num_quantizers):
                qlstm._input_quantizers[idx].disable()
                qlstm._weight_quantizers[idx].disable()

            ref_lstm = nn.LSTM(
                input_size, hidden_size, num_layers=num_layers, bias=bias,
                batch_first=batch_first, dropout=dropout,
                bidirectional=bidirectional)
            # Share identical parameters between the two modules.
            ref_lstm.load_state_dict(qlstm.state_dict())

            num_directions = 2 if bidirectional else 1
            inputs = torch.randn(seq_len, batch, input_size)
            hidden = torch.randn(num_layers * num_directions, batch, hidden_size)
            cell = torch.randn(num_layers * num_directions, batch, hidden_size)

            quant_out, (quant_hout, quant_cout) = qlstm(inputs, hx=(hidden, cell))
            ref_out, (ref_hout, ref_cout) = ref_lstm(inputs, hx=(hidden, cell))

            utils.compare(quant_out, ref_out)
            utils.compare(quant_hout, ref_hout)
            utils.compare(quant_cout, ref_cout)
# Example 4
    def test_against_unquantized(self, verbose):
        """Quantization should introduce bounded error utils.compare to pytorch implementation."""
        seq_len, batch = 1, 21
        input_size, hidden_size = 33, 25

        # 16-bit quantization: high enough precision for tight tolerances,
        # but still different from the exact reference.
        desc_input = tensor_quant.QuantDescriptor(num_bits=16)
        desc_weight = tensor_quant.QuantDescriptor(num_bits=16, axis=(1,))
        qlstm = quant_rnn.QuantLSTM(
            input_size, hidden_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False,
            quant_desc_input=desc_input, quant_desc_weight=desc_weight)
        ref_lstm = nn.LSTM(
            input_size, hidden_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False)

        # Share identical weights between the two modules.
        ref_lstm.load_state_dict(qlstm.state_dict())

        inputs = torch.randn(seq_len, batch, input_size)
        hidden = torch.randn(seq_len, batch, hidden_size)
        cell = torch.randn(seq_len, batch, hidden_size)

        quant_out, (quant_hout, quant_cout) = qlstm(inputs, hx=(hidden, cell))
        ref_out, (ref_hout, ref_cout) = ref_lstm(inputs, hx=(hidden, cell))

        # The difference between reference and quantized should be bounded in a range
        # Small values which become 0 after quantization lead to large relative errors. rtol and atol could be
        # much smaller without those values
        utils.compare(quant_out, ref_out, rtol=1e-4, atol=1e-4)
        utils.compare(quant_hout, ref_hout, rtol=1e-4, atol=1e-4)
        utils.compare(quant_cout, ref_cout, rtol=1e-4, atol=1e-4)

        # check that quantization introduces some error
        utils.assert_min_mse(quant_out, ref_out, tol=1e-20)
        utils.assert_min_mse(quant_hout, ref_hout, tol=1e-20)
        utils.assert_min_mse(quant_cout, ref_cout, tol=1e-20)
# Example 5
    def test_basic_forward(self, verbose):
        """Do a forward pass on the layer module and see if anything catches fire."""
        seq_len, batch = 1, 5
        input_size, hidden_size = 13, 31

        desc_input = tensor_quant.QuantDescriptor(num_bits=8)
        desc_weight = tensor_quant.QuantDescriptor(num_bits=8, axis=(1,))
        qlstm = quant_rnn.QuantLSTM(
            input_size, hidden_size, num_layers=1, bias=False,
            batch_first=False, dropout=0, bidirectional=False,
            quant_desc_input=desc_input, quant_desc_weight=desc_weight)

        inputs = torch.randn(seq_len, batch, input_size)
        hidden = torch.randn(seq_len, batch, hidden_size)
        cell = torch.randn(seq_len, batch, hidden_size)

        # Smoke test only: output is intentionally not checked.
        qlstm(inputs, hx=(hidden, cell))