def test_quantized_rnn(self):
    d_in, d_hid = 2, 2
    model = LSTMDynamicModel().eval()
    cell = model.lstm

    # Replace parameter values s.t. the range of values is exactly
    # 255, thus we will have 0 quantization error in the quantized
    # GEMM call. This is for testing purposes.
    #
    # Note that the current implementation does not support
    # accumulation values outside of the range representable by a
    # 16 bit integer, instead resulting in a saturated value. We
    # must take care that in our test we do not end up with a dot
    # product that overflows the int16 range, e.g.
    # (255*127+255*127) = 64770. So, we hardcode the test values
    # here and ensure a mix of signedness.
    vals = [[100, -155],
            [100, -155],
            [-155, 100],
            [-155, 100],
            [100, -155],
            [-155, 100],
            [-155, 100],
            [100, -155]]
    if isinstance(cell, torch.nn.LSTM):
        num_chunks = 4
    vals = vals[:d_hid * num_chunks]
    cell.weight_ih_l0 = torch.nn.Parameter(
        torch.tensor(vals, dtype=torch.float),
        requires_grad=False)
    cell.weight_hh_l0 = torch.nn.Parameter(
        torch.tensor(vals, dtype=torch.float),
        requires_grad=False)

    ref = copy.deepcopy(cell)

    model_int8 = quantize_dynamic(model=model, dtype=torch.qint8)
    model_fp16 = quantize_dynamic(model=model, dtype=torch.float16)

    # Smoke test extra reprs
    self.assertTrue('DynamicQuantizedLSTM' in str(model_int8))
    self.assertTrue('DynamicQuantizedLSTM' in str(model_fp16))

    cell_int8 = model_int8.lstm
    cell_fp16 = model_fp16.lstm

    assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \
        'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'
    assert type(cell_fp16) == torch.nn.quantized.dynamic.LSTM, \
        'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'

    niter = 10
    x = torch.tensor([[100, -155],
                      [-155, 100],
                      [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)

    h0_vals = [[-155, 100],
               [-155, 155],
               [100, -155]]

    hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
    cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)

    if isinstance(ref, torch.nn.LSTM):
        hiddens = (hx, cx)

    ref_out, ref_hid = ref(x, hiddens)

    # Compare int8 quantized to unquantized
    output_int8, final_hiddens_int8 = cell_int8(x, hiddens)
    torch.testing.assert_allclose(output_int8, ref_out)
    self.assertEqual(output_int8, ref_out)
    for out_val, ref_val in zip(final_hiddens_int8, ref_hid):
        torch.testing.assert_allclose(out_val, ref_val)

    class ScriptWrapper(torch.nn.Module):
        def __init__(self, cell):
            super(ScriptWrapper, self).__init__()
            self.cell = cell

        def forward(self, x, hiddens):
            # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            return self.cell(x, hiddens)

    # TODO: TorchScript overloads don't work without this wrapper
    cell_script = torch.jit.script(ScriptWrapper(cell_int8))
    out_script, hid_script = cell_script(x, hiddens)
    self.assertEqual(len(out_script), len(ref_out))
    for out_val, ref_val in zip(out_script, ref_out):
        torch.testing.assert_allclose(out_val, ref_val)

    # Test save/load
    b = io.BytesIO()
    torch.jit.save(cell_script, b)
    b.seek(0)
    loaded = torch.jit.load(b)
    out_loaded, hid_loaded = loaded(x, hiddens)
    for loaded_val, ref_val in zip(out_loaded, ref_out):
        torch.testing.assert_allclose(loaded_val, ref_val)

    # Compare fp16 quantized to unquantized
    output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)
    torch.testing.assert_allclose(output_fp16, ref_out)
    self.assertEqual(output_fp16, ref_out)
    for out_val, ref_val in zip(final_hiddens_fp16, ref_hid):
        torch.testing.assert_allclose(out_val, ref_val)
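# ---------------------------------------------------------------------------
# The test above relies on a small helper module, LSTMDynamicModel, and on a
# few module-level imports that are defined elsewhere in the test file. The
# following is a minimal sketch of what they might look like, assuming the
# helper simply wraps a single 2-input / 2-hidden nn.LSTM (matching
# d_in = d_hid = 2 in the test); treat it as an illustrative assumption
# rather than the actual definition used by the test suite.
# ---------------------------------------------------------------------------
import copy
import io
from typing import Tuple

import torch
from torch.quantization import quantize_dynamic


class LSTMDynamicModel(torch.nn.Module):
    def __init__(self):
        super(LSTMDynamicModel, self).__init__()
        # Single-layer LSTM with 2 input features and 2 hidden units; its
        # weights are overwritten with hand-picked values inside the test.
        self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

    def forward(self, x):
        x = self.lstm(x)
        return x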