def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
    # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
    if hx is None:
        num_directions = 2 if self.bidirectional else 1
        hx = torch.zeros(self.num_layers * num_directions,
                         max_batch_size, self.hidden_size,
                         dtype=input.dtype, device=input.device)
    else:
        # Each batch of the hidden state should match the input sequence that
        # the user believes he/she is passing in.
        hx = self.permute_hidden(hx, sorted_indices)

    self.check_forward_args(input, hx, batch_sizes)

    if batch_sizes is None:
        result = torch.quantized_gru(input, hx, self.all_weights, self.bias,
                                     self.num_layers, float(self.dropout),
                                     self.training, self.bidirectional,
                                     self.batch_first)
    else:
        result = torch.quantized_gru(input, batch_sizes, hx, self.all_weights,
                                     self.bias, self.num_layers,
                                     float(self.dropout), self.training,
                                     self.bidirectional)
    output = result[0]
    hidden = result[1]
    return output, hidden
def forward_impl(
        self, input: Tensor, hx: Optional[Tensor],
        batch_sizes: Optional[Tensor], max_batch_size: int,
        sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
    if hx is None:
        num_directions = 2 if self.bidirectional else 1
        zeros = torch.zeros(self.num_layers * num_directions,
                            max_batch_size, self.hidden_size,
                            dtype=input.dtype, device=input.device)
        hx = zeros
    else:
        # Each batch of the hidden state should match the input sequence that
        # the user believes he/she is passing in.
        hx = self.permute_hidden(hx, sorted_indices)

    self.check_forward_args(input, hx, batch_sizes)

    # Gather the packed (quantized) per-layer weight/bias params, which live
    # in self._all_weight_values rather than in self.all_weights.
    _all_params = [m.param for m in self._all_weight_values]
    if batch_sizes is None:
        result = torch.quantized_gru(input, hx, _all_params, self.bias,
                                     self.num_layers, self.dropout,
                                     self.training, self.bidirectional,
                                     self.batch_first)
    else:
        result = torch.quantized_gru(input, batch_sizes, hx, _all_params,
                                     self.bias, self.num_layers, self.dropout,
                                     self.training, self.bidirectional)
    output = result[0]
    hidden = result[1]
    return output, hidden
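# A minimal usage sketch, not part of the module source above. It assumes the
# module is exposed as torch.nn.quantized.dynamic.GRU (as in the test below)
# and illustrates the two branches of forward_impl: the padded-tensor path
# (batch_sizes is None, hx defaulted to zeros) and the PackedSequence path
# (batch_sizes supplied by pack_padded_sequence).
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

qgru = torch.nn.quantized.dynamic.GRU(input_size=3, hidden_size=7,
                                      num_layers=2, dtype=torch.qint8)

# Padded path: a (seq_len, batch, input_size) float tensor, hidden defaulted.
x = torch.rand(4, 2, 3)
out, h_n = qgru(x)  # out: (4, 2, 7), h_n: (2, 2, 7)

# Packed path: variable-length sequences take the batch_sizes branch.
lengths = torch.tensor([4, 2])  # must be sorted in decreasing order here
packed = pack_padded_sequence(x, lengths)
packed_out, h_n = qgru(packed)
out, out_lengths = pad_packed_sequence(packed_out)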
def test_gru_api(self):
    r"""Test execution and serialization for dynamic quantized GRU modules on int8 and fp16
    """
    # Check that module matches the numerics of the op and ensure that module can be
    # instantiated for all engines and dtypes
    for dtype in [torch.qint8, torch.float16]:
        if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack":
            # fp16 dynamic quant is not supported for qnnpack
            continue
        # Test default instantiation
        seq_len = 4
        batch = 2
        input_size = 3
        hidden_size = 7
        num_layers = 2
        bias = True
        bidirectional = False

        x = torch.rand(seq_len, batch, input_size)
        h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)

        cell_dq = torch.nn.quantized.dynamic.GRU(input_size=input_size,
                                                 hidden_size=hidden_size,
                                                 num_layers=num_layers,
                                                 bias=bias,
                                                 batch_first=False,
                                                 dropout=0.0,
                                                 bidirectional=bidirectional,
                                                 dtype=dtype)

        # Call the underlying op directly with the module's packed params and
        # check that the module forward matches its numerics.
        _all_params = [m.param for m in cell_dq._all_weight_values]
        result = torch.quantized_gru(x,
                                     h,
                                     _all_params,
                                     cell_dq.bias,
                                     cell_dq.num_layers,
                                     float(cell_dq.dropout),
                                     False,
                                     bidirectional,
                                     False)
        y, h = cell_dq(x, h)
        self.assertEqual(result[0], y, msg="GRU module API failed")
        self.assertEqual(result[1], h, msg="GRU module API failed")
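# A hedged sketch of the serialization half mentioned in the docstring above:
# round-tripping the dynamic quantized GRU through a state_dict and through
# TorchScript, then checking the outputs still match. The constructor
# arguments and the in-memory buffer are illustrative choices, not taken from
# the original test.
import io
import torch

cell_dq = torch.nn.quantized.dynamic.GRU(input_size=3, hidden_size=7,
                                         num_layers=2, dtype=torch.qint8)
x = torch.rand(4, 2, 3)
y_ref, h_ref = cell_dq(x)

# state_dict round trip into a fresh module built with the same arguments.
cell2 = torch.nn.quantized.dynamic.GRU(input_size=3, hidden_size=7,
                                       num_layers=2, dtype=torch.qint8)
cell2.load_state_dict(cell_dq.state_dict())
y2, h2 = cell2(x)
assert torch.equal(y_ref, y2) and torch.equal(h_ref, h2)

# TorchScript round trip via an in-memory buffer.
buf = io.BytesIO()
torch.jit.save(torch.jit.script(cell_dq), buf)
buf.seek(0)
loaded = torch.jit.load(buf)
y3, h3 = loaded(x)
assert torch.equal(y_ref, y3) and torch.equal(h_ref, h3)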