def test_no_quant_input_hidden(self, verbose):
    """QuantLSTMCell with quantization disabled vs. pytorch LSTMCell for input and hidden inputs."""
    batch = 17
    input_size = 13
    hidden_size = 7

    quant_rnn_object = quant_rnn.QuantLSTMCell(input_size, hidden_size, bias=False)
    quant_rnn_object._input_quantizer.disable()
    quant_rnn_object._weight_quantizer.disable()
    ref_rnn_object = nn.LSTMCell(input_size, hidden_size, bias=False)

    # copy weights from one rnn to the other
    ref_rnn_object.load_state_dict(quant_rnn_object.state_dict())

    input = torch.randn(batch, input_size)
    hidden = torch.randn(batch, hidden_size)
    cell = torch.randn(batch, hidden_size)

    quant_hout, quant_cout = quant_rnn_object(input, hx=(hidden, cell))
    ref_hout, ref_cout = ref_rnn_object(input, hx=(hidden, cell))

    # with all quantizers disabled, the outputs should match the unquantized reference
    utils.compare(quant_hout, ref_hout)
    utils.compare(quant_cout, ref_cout)
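
# utils.compare is the test suite's tensor-comparison helper. A minimal sketch of
# its assumed behavior (hypothetical name _compare_sketch; the exact-match default
# is an assumption based on how the tests above call it without tolerances):
def _compare_sketch(a, b, rtol=0.0, atol=0.0):
    """Assert tensors a and b match within rtol/atol (exact match by default)."""
    assert torch.allclose(a, b, rtol=rtol, atol=atol), \
        "max abs diff {}".format((a - b).abs().max().item())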

def test_quant_input_hidden(self, verbose):
    """QuantLSTMCell vs. manual input quantization + pytorch LSTMCell."""
    batch = 15
    input_size = 121
    hidden_size = 51
    num_bits = 4

    quant_desc_input = tensor_quant.QuantDescriptor(num_bits=num_bits)
    quant_desc_weight = tensor_quant.QuantDescriptor(num_bits=num_bits)
    quant_rnn_object = quant_rnn.QuantLSTMCell(
        input_size, hidden_size, bias=False,
        quant_desc_input=quant_desc_input,
        quant_desc_weight=quant_desc_weight)
    ref_rnn_object = nn.LSTMCell(input_size, hidden_size, bias=False)

    input = torch.randn(batch, input_size)
    hidden = torch.randn(batch, hidden_size)
    cell = torch.randn(batch, hidden_size)

    quant_hout, quant_cout = quant_rnn_object(input, hx=(hidden, cell))

    # quantize the input and hidden state manually, then feed them to an
    # unquantized LSTMCell whose weights were quantized in place
    quant_input, quant_hidden = utils.quantize_by_range_fused((input, hidden), num_bits)
    utils.copy_state_and_quantize_fused(ref_rnn_object, quant_rnn_object, num_bits)
    ref_hout, ref_cout = ref_rnn_object(quant_input, hx=(quant_hidden, cell))

    utils.compare(quant_hout, ref_hout)
    utils.compare(quant_cout, ref_cout)
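
# "Fused" is assumed to mean that input and hidden share a single dynamic range,
# matching the single _input_quantizer in QuantLSTMCell. A minimal sketch of the
# assumed helper (hypothetical name; the real one lives in the test utils):
def _quantize_by_range_fused_sketch(x_tuple, num_bits):
    """Fake-quantize every tensor in x_tuple with one shared amax."""
    amax = max(x.abs().max() for x in x_tuple)
    return tuple(tensor_quant.fake_tensor_quant(x, amax, num_bits) for x in x_tuple)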

def test_against_unquantized(self, verbose):
    """Quantization should introduce bounded error compared to the pytorch implementation."""
    batch = 9
    input_size = 13
    hidden_size = 7

    quant_desc_input = tensor_quant.QuantDescriptor(num_bits=16)
    quant_desc_weight = tensor_quant.QuantDescriptor(num_bits=16, axis=(1,))
    quant_rnn_object = quant_rnn.QuantLSTMCell(
        input_size, hidden_size, bias=False,
        quant_desc_input=quant_desc_input,
        quant_desc_weight=quant_desc_weight)
    ref_rnn_object = nn.LSTMCell(input_size, hidden_size, bias=False)

    # copy weights from one rnn to the other
    ref_rnn_object.load_state_dict(quant_rnn_object.state_dict())

    input = torch.randn(batch, input_size)
    hidden = torch.randn(batch, hidden_size)
    cell = torch.randn(batch, hidden_size)

    quant_hout, quant_cout = quant_rnn_object(input, hx=(hidden, cell))
    ref_hout, ref_cout = ref_rnn_object(input, hx=(hidden, cell))

    # The difference between the reference and quantized outputs should be bounded.
    # Small values that become 0 after quantization lead to large relative errors;
    # rtol and atol could be much smaller without those values.
    utils.compare(quant_hout, ref_hout, rtol=1e-4, atol=1e-4)
    utils.compare(quant_cout, ref_cout, rtol=1e-4, atol=1e-4)

    # check that quantization introduces some error
    utils.assert_min_mse(quant_hout, ref_hout, tol=1e-20)
    utils.assert_min_mse(quant_cout, ref_cout, tol=1e-20)
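
# utils.assert_min_mse guards against the test above passing trivially (e.g. with
# quantizers silently disabled). A minimal sketch of its assumed behavior
# (hypothetical name; illustration only):
def _assert_min_mse_sketch(a, b, tol=1e-20):
    """Assert the mean squared error between a and b is at least tol."""
    mse = torch.mean((a - b) ** 2).item()
    assert mse > tol, "MSE {} <= {}; quantization appears to be a no-op".format(mse, tol)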

def test_basic_forward(self, verbose):
    """Do a forward pass on the cell module and see if anything catches fire."""
    batch = 7
    input_size = 11
    hidden_size = 9

    quant_desc_input = tensor_quant.QuantDescriptor(num_bits=8)
    quant_desc_weight = tensor_quant.QuantDescriptor(num_bits=8, axis=(1,))
    quant_rnn_object = quant_rnn.QuantLSTMCell(
        input_size, hidden_size, bias=False,
        quant_desc_input=quant_desc_input,
        quant_desc_weight=quant_desc_weight)
    quant_rnn_object._input_quantizer.disable()
    quant_rnn_object._weight_quantizer.disable()

    input = torch.randn(batch, input_size)
    hidden = torch.randn(batch, hidden_size)
    cell = torch.randn(batch, hidden_size)

    quant_rnn_object(input, hx=(hidden, cell))