def test_quant_different_prec(self, verbose):
    """QuantLSTM vs. manual input quantization + pytorchLSTM."""
    batch = 22
    input_size = 23
    hidden_size = 24
    seq_len = 1
    num_bits_weight = 4
    num_bits_input = 8

    # Quantized LSTM with independent precisions for inputs and weights.
    desc_input = tensor_quant.QuantDescriptor(num_bits=num_bits_input)
    desc_weight = tensor_quant.QuantDescriptor(num_bits=num_bits_weight)
    qlstm = quant_rnn.QuantLSTM(
        input_size, hidden_size, num_layers=1, bias=False, batch_first=False,
        dropout=0, bidirectional=False,
        quant_desc_input=desc_input, quant_desc_weight=desc_weight)
    ref_lstm = nn.LSTM(
        input_size, hidden_size, num_layers=1, bias=False, batch_first=False,
        dropout=0, bidirectional=False)

    inp = torch.randn(seq_len, batch, input_size)
    hidden = torch.randn(seq_len, batch, hidden_size)
    cell = torch.randn(seq_len, batch, hidden_size)

    # Manually quantize the reference path: input and hidden state only —
    # the cell state is not quantized by QuantLSTM, so it passes through as-is.
    quant_input, quant_hidden = utils.quantize_by_range_fused((inp, hidden), num_bits_input)
    # Copy weights into the reference LSTM, quantized to the weight precision.
    utils.copy_state_and_quantize_fused(ref_lstm, qlstm, num_bits_weight)

    quant_out, (quant_hout, quant_cout) = qlstm(inp, hx=(hidden, cell))
    ref_out, (ref_hout, ref_cout) = ref_lstm(quant_input, hx=(quant_hidden, cell))

    utils.compare(quant_out, ref_out)
    utils.compare(quant_hout, ref_hout)
    utils.compare(quant_cout, ref_cout)
def test_no_quant_input_hidden(self, verbose):
    """QuantLSTM with quantization disabled vs. pytorch LSTM for input and hidden inputs."""
    batch = 13
    input_size = 19
    hidden_size = 20
    seq_len = 1

    qlstm = quant_rnn.QuantLSTM(
        input_size, hidden_size, num_layers=1, bias=False, batch_first=False,
        dropout=0, bidirectional=False)
    # With both quantizers disabled the module should match nn.LSTM exactly.
    qlstm._input_quantizers[0].disable()
    qlstm._weight_quantizers[0].disable()

    ref_lstm = nn.LSTM(
        input_size, hidden_size, num_layers=1, bias=False, batch_first=False,
        dropout=0, bidirectional=False)
    # copy weights from one rnn to the other
    ref_lstm.load_state_dict(qlstm.state_dict())

    inp = torch.randn(seq_len, batch, input_size)
    hidden = torch.randn(seq_len, batch, hidden_size)
    cell = torch.randn(seq_len, batch, hidden_size)

    quant_out, (quant_hout, quant_cout) = qlstm(inp, hx=(hidden, cell))
    ref_out, (ref_hout, ref_cout) = ref_lstm(inp, hx=(hidden, cell))

    utils.compare(quant_out, ref_out)
    utils.compare(quant_hout, ref_hout)
    utils.compare(quant_cout, ref_cout)
def testcase(input_size, hidden_size, seq_len, batch, num_layers, bias,
             batch_first, dropout, bidirectional):
    """Compare QuantLSTM (quantization disabled) against nn.LSTM for one config."""
    qlstm = quant_rnn.QuantLSTM(
        input_size, hidden_size, num_layers=num_layers, bias=bias,
        batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)

    num_directions = 2 if bidirectional else 1
    # One input/weight quantizer pair per layer and direction; disable them all
    # so the quantized module should reproduce nn.LSTM bit-for-bit.
    for idx in range(num_layers * num_directions):
        qlstm._input_quantizers[idx].disable()
        qlstm._weight_quantizers[idx].disable()

    ref_lstm = nn.LSTM(
        input_size, hidden_size, num_layers=num_layers, bias=bias,
        batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)
    # copy state from one rnn to the other
    ref_lstm.load_state_dict(qlstm.state_dict())

    inp = torch.randn(seq_len, batch, input_size)
    hidden = torch.randn(num_layers * num_directions, batch, hidden_size)
    cell = torch.randn(num_layers * num_directions, batch, hidden_size)

    quant_out, (quant_hout, quant_cout) = qlstm(inp, hx=(hidden, cell))
    ref_out, (ref_hout, ref_cout) = ref_lstm(inp, hx=(hidden, cell))

    utils.compare(quant_out, ref_out)
    utils.compare(quant_hout, ref_hout)
    utils.compare(quant_cout, ref_cout)
def test_against_unquantized(self, verbose):
    """Quantization should introduce bounded error utils.compare to pytorch implementation."""
    batch = 21
    input_size = 33
    hidden_size = 25
    seq_len = 1

    # High-precision (16-bit) quantization: error should be tiny but nonzero.
    desc_input = tensor_quant.QuantDescriptor(num_bits=16)
    desc_weight = tensor_quant.QuantDescriptor(num_bits=16, axis=(1, ))
    qlstm = quant_rnn.QuantLSTM(
        input_size, hidden_size, num_layers=1, bias=False, batch_first=False,
        dropout=0, bidirectional=False,
        quant_desc_input=desc_input, quant_desc_weight=desc_weight)
    ref_lstm = nn.LSTM(
        input_size, hidden_size, num_layers=1, bias=False, batch_first=False,
        dropout=0, bidirectional=False)
    # copy weights from one rnn to the other
    ref_lstm.load_state_dict(qlstm.state_dict())

    inp = torch.randn(seq_len, batch, input_size)
    hidden = torch.randn(seq_len, batch, hidden_size)
    cell = torch.randn(seq_len, batch, hidden_size)

    quant_out, (quant_hout, quant_cout) = qlstm(inp, hx=(hidden, cell))
    ref_out, (ref_hout, ref_cout) = ref_lstm(inp, hx=(hidden, cell))

    # The difference between reference and quantized should be bounded in a range
    # Small values which become 0 after quantization lead to large relative errors. rtol and atol could be
    # much smaller without those values
    utils.compare(quant_out, ref_out, rtol=1e-4, atol=1e-4)
    utils.compare(quant_hout, ref_hout, rtol=1e-4, atol=1e-4)
    utils.compare(quant_cout, ref_cout, rtol=1e-4, atol=1e-4)

    # check that quantization introduces some error
    utils.assert_min_mse(quant_out, ref_out, tol=1e-20)
    utils.assert_min_mse(quant_hout, ref_hout, tol=1e-20)
    utils.assert_min_mse(quant_cout, ref_cout, tol=1e-20)
def test_basic_forward(self, verbose):
    """Do a forward pass on the layer module and see if anything catches fire."""
    batch = 5
    input_size = 13
    hidden_size = 31
    seq_len = 1

    desc_input = tensor_quant.QuantDescriptor(num_bits=8)
    desc_weight = tensor_quant.QuantDescriptor(num_bits=8, axis=(1,))
    qlstm = quant_rnn.QuantLSTM(
        input_size, hidden_size, num_layers=1, bias=False, batch_first=False,
        dropout=0, bidirectional=False,
        quant_desc_input=desc_input, quant_desc_weight=desc_weight)

    inp = torch.randn(seq_len, batch, input_size)
    hidden = torch.randn(seq_len, batch, hidden_size)
    cell = torch.randn(seq_len, batch, hidden_size)

    # Smoke test only: no output checks, just ensure the forward pass runs.
    qlstm(inp, hx=(hidden, cell))