def forward_impl(
    self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
    batch_sizes: Optional[Tensor], max_batch_size: int,
    sorted_indices: Optional[Tensor]
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    if hx is None:
        num_directions = 2 if self.bidirectional else 1
        zeros = torch.zeros(self.num_layers * num_directions,
                            max_batch_size, self.hidden_size,
                            dtype=input.dtype, device=input.device)
        hx = (zeros, zeros)
    else:
        # Each batch of the hidden state should match the input sequence
        # that the user believes they are passing in.
        hx = self.permute_hidden(hx, sorted_indices)

    self.check_forward_args(input, hx, batch_sizes)
    _all_params = [m.param for m in self._all_weight_values]
    if batch_sizes is None:
        result = torch.quantized_lstm(input, hx, _all_params, self.bias,
                                      self.num_layers, float(self.dropout),
                                      self.training, self.bidirectional,
                                      self.batch_first, dtype=self.dtype,
                                      use_dynamic=True)
    else:
        # Packed-sequence path: batch_sizes is forwarded to the op instead
        # of batch_first.
        result = torch.quantized_lstm(input, batch_sizes, hx, _all_params,
                                      self.bias, self.num_layers,
                                      float(self.dropout), self.training,
                                      self.bidirectional, dtype=self.dtype,
                                      use_dynamic=True)
    output = result[0]
    hidden = result[1:]
    return output, hidden
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
    # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    if hx is None:
        num_directions = 2 if self.bidirectional else 1
        zeros = torch.zeros(self.num_layers * num_directions,
                            max_batch_size, self.hidden_size,
                            dtype=input.dtype, device=input.device)
        hx = (zeros, zeros)
    else:
        # Each batch of the hidden state should match the input sequence
        # that the user believes they are passing in.
        hx = self.permute_hidden(hx, sorted_indices)

    self.check_forward_args(input, hx, batch_sizes)
    assert batch_sizes is None  # packed input is not supported on this path
    result = torch.quantized_lstm(input, hx, self.all_weights, self.bias,
                                  self.num_layers, float(self.dropout),
                                  self.training, self.bidirectional,
                                  self.batch_first, dtype=self.dtype,
                                  use_dynamic=False)
    output = result[0]
    hidden = result[1:]
    return output, hidden
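The key behavioral difference between the two versions of forward_impl above is that the first one dispatches to the packed-sequence overload of torch.quantized_lstm when batch_sizes is present (with use_dynamic=True), while the second asserts that packed input never reaches it. A minimal sketch of driving the new path through the public module API, assuming the dynamic LSTM accepts PackedSequence input like its float counterpart (the sizes here are illustrative, not from the source):

# Sketch only: exercises the batch_sizes branch of forward_impl above.
# pack_padded_sequence produces the batch_sizes tensor that forward_impl
# forwards to torch.quantized_lstm.
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm_dq = torch.nn.quantized.dynamic.LSTM(input_size=3, hidden_size=7,
                                          num_layers=2, dtype=torch.qint8)
x = torch.randn(4, 2, 3)                   # (seq_len, batch, input_size)
lengths = torch.tensor([4, 2])             # valid length of each sequence
packed = pack_padded_sequence(x, lengths)  # carries batch_sizes internally
out_packed, (h, c) = lstm_dq(packed)       # takes the batch_sizes-is-not-None path
out, out_lengths = pad_packed_sequence(out_packed)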
def test_lstm_api(self):
    r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
    """
    # Check that module matches the numerics of the op and ensure that module can be
    # instantiated for all engines and dtypes
    for dtype in [torch.qint8, torch.float16]:
        if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack":
            # fp16 dynamic quant is not supported for qnnpack
            continue
        # Test default instantiation
        seq_len = 4
        batch = 2
        input_size = 3
        hidden_size = 7
        num_layers = 2
        bias = True
        bidirectional = False

        x = torch.randn(seq_len, batch, input_size)
        h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
        c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)

        cell_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                  hidden_size=hidden_size,
                                                  num_layers=num_layers,
                                                  bias=bias,
                                                  batch_first=False,
                                                  dropout=0.0,
                                                  bidirectional=bidirectional,
                                                  dtype=dtype)

        _all_params = [m.param for m in cell_dq._all_weight_values]
        result = torch.quantized_lstm(x, (h, c), _all_params, cell_dq.bias,
                                      cell_dq.num_layers, float(cell_dq.dropout),
                                      False, bidirectional, False,
                                      dtype=dtype, use_dynamic=True)

        y, (h, c) = cell_dq(x, (h, c))
        self.assertEqual(result[0], y)
        self.assertEqual(result[1], h)
        self.assertEqual(result[2], c)
def test_lstm_api(self, dtype, bidirectional):
    r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
    """
    # Check that module matches the numerics of the op and ensure that module can be
    # instantiated for all engines and dtypes
    seq_len = 4
    batch = 2
    input_size = 3
    hidden_size = 7
    num_layers = 2
    bias = True
    weight_keys = []
    bias_keys = []
    num_directions = 2 if bidirectional else 1
    for layer in range(num_layers):
        for direction in range(num_directions):
            suffix = '_reverse' if direction == 1 else ''
            key_name1 = 'weight_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
            key_name2 = 'weight_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
            weight_keys.append(key_name1)
            weight_keys.append(key_name2)
            key_name1 = 'bias_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
            key_name2 = 'bias_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
            bias_keys.append(key_name1)
            bias_keys.append(key_name2)

    if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"):
        # fp16 dynamic quant is not supported for qnnpack
        x = torch.randn(seq_len, batch, input_size)
        h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
        c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)

        cell_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                  hidden_size=hidden_size,
                                                  num_layers=num_layers,
                                                  bias=bias,
                                                  batch_first=False,
                                                  dropout=0.0,
                                                  bidirectional=bidirectional,
                                                  dtype=dtype)
        ref_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                 hidden_size=hidden_size,
                                                 num_layers=num_layers,
                                                 bias=bias,
                                                 batch_first=False,
                                                 dropout=0.0,
                                                 bidirectional=bidirectional,
                                                 dtype=dtype)

        _all_params = [m.param for m in cell_dq._all_weight_values]
        result = torch.quantized_lstm(x, (h, c), _all_params, cell_dq.bias,
                                      cell_dq.num_layers, float(cell_dq.dropout),
                                      False, bidirectional, False,
                                      dtype=dtype, use_dynamic=True)

        y, (h, c) = cell_dq(x, (h, c))
        self.assertEqual(result[0], y)
        self.assertEqual(result[1], h)
        self.assertEqual(result[2], c)

        x = torch.randn(10, 20, 3)
        self.check_eager_serialization(cell_dq, ref_dq, [x])
        self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)
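The serialization half of the test goes through check_eager_serialization, a test helper whose body is not shown here. A rough sketch of the round trip it presumably relies on (this is an assumption, not the helper's actual implementation): serialize one module's state_dict, restore it into a second instance, and verify both produce identical outputs.

# Assumed sketch of an eager-serialization round trip; roundtrip_state_dict
# is a hypothetical name, not a helper from the source.
import io
import torch

def roundtrip_state_dict(src, dst, x):
    buf = io.BytesIO()
    torch.save(src.state_dict(), buf)   # serialize packed weights and biases
    buf.seek(0)
    dst.load_state_dict(torch.load(buf))
    y_src, _ = src(x)
    y_dst, _ = dst(x)
    assert torch.equal(y_src, y_dst)    # restored module matches the original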