Example #1
    def forward_impl(self, input, hx, batch_sizes, max_batch_size,
                     sorted_indices):
        # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size,
                             self.hidden_size,
                             dtype=input.dtype,
                             device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence
            # that the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = torch.quantized_gru(input, hx, self.all_weights,
                                         self.bias, self.num_layers,
                                         float(self.dropout), self.training,
                                         self.bidirectional, self.batch_first)
        else:
            result = torch.quantized_gru(input, batch_sizes, hx,
                                         self.all_weights,
                                         self.bias, self.num_layers,
                                         float(self.dropout), self.training,
                                         self.bidirectional)

        output = result[0]
        hidden = result[1]

        return output, hidden
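
The branch on batch_sizes above picks between the plain-tensor and packed-sequence overloads of torch.quantized_gru, and the hx is None branch zero-fills the initial hidden state. A minimal sketch of reaching this code through the public module, assuming the same torch.nn.quantized.dynamic.GRU path that the test in Example #2 uses:

import torch

# Dynamically quantized GRU; calling it dispatches into forward_impl().
gru_dq = torch.nn.quantized.dynamic.GRU(input_size=3, hidden_size=7,
                                        num_layers=2, bias=True,
                                        batch_first=False, dropout=0.0,
                                        bidirectional=False, dtype=torch.qint8)

x = torch.rand(4, 2, 3)   # (seq_len, batch, input_size)
y, h = gru_dq(x)          # hx is None here, so forward_impl zero-initializes it
print(y.shape, h.shape)   # torch.Size([4, 2, 7]) torch.Size([2, 2, 7])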
Example #2
    def forward_impl(
            self, input: Tensor, hx: Optional[Tensor],
            batch_sizes: Optional[Tensor], max_batch_size: int,
            sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size,
                                self.hidden_size,
                                dtype=input.dtype,
                                device=input.device)
            hx = zeros
        else:
            # Each batch of the hidden state should match the input sequence
            # that the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        # Gather the packed weight/bias params consumed by the quantized op
        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_gru(input, hx, _all_params, self.bias,
                                         self.num_layers, self.dropout,
                                         self.training, self.bidirectional,
                                         self.batch_first)
        else:
            result = torch.quantized_gru(input, batch_sizes, hx, _all_params,
                                         self.bias, self.num_layers,
                                         self.dropout, self.training,
                                         self.bidirectional)
        output = result[0]
        hidden = result[1]

        return output, hidden
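
    # A hedged illustration (not part of the original example): the packed
    # branch above (batch_sizes is not None) is reached by passing a
    # PackedSequence into the module, gru_dq being a hypothetical
    # torch.nn.quantized.dynamic.GRU instance:
    #
    #     from torch.nn.utils.rnn import pack_padded_sequence
    #     packed = pack_padded_sequence(torch.rand(4, 2, 3), lengths=[4, 3])
    #     y_packed, h = gru_dq(packed)  # forward() fills in batch_sizes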
    def test_gru_api(self):
        r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
        """
        # Check that the module matches the numerics of the op and ensure that
        # the module can be instantiated for all engines and dtypes

        for dtype in [torch.qint8, torch.float16]:
            if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack":
                # fp16 dynamic quant is not supported for qnnpack
                continue

            # Test default instantiation
            seq_len = 4
            batch = 2
            input_size = 3
            hidden_size = 7
            num_layers = 2
            bias = True
            bidirectional = False

            x = torch.rand(seq_len, batch, input_size)
            h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)

            cell_dq = torch.nn.quantized.dynamic.GRU(input_size=input_size,
                                                     hidden_size=hidden_size,
                                                     num_layers=num_layers,
                                                     bias=bias,
                                                     batch_first=False,
                                                     dropout=0.0,
                                                     bidirectional=bidirectional,
                                                     dtype=dtype)

            # Extract the packed params so the op can be called directly
            _all_params = [m.param for m in cell_dq._all_weight_values]
            result = torch.quantized_gru(x,
                                         h,
                                         _all_params,
                                         cell_dq.bias,
                                         cell_dq.num_layers,
                                         float(cell_dq.dropout),
                                         False,  # training
                                         bidirectional,
                                         False)  # batch_first

            y, h = cell_dq(x, h)
            self.assertEqual(result[0], y, msg="GRU module API failed")
            self.assertEqual(result[1], h, msg="GRU module API failed")
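
In practice a module like cell_dq above usually comes from converting a float model rather than from direct instantiation. A minimal sketch, assuming torch.quantization.quantize_dynamic with qint8 (the tolerance below is an illustrative guess, not a documented bound):

import torch
import torch.nn as nn

float_gru = nn.GRU(input_size=3, hidden_size=7, num_layers=2)
gru_dq = torch.quantization.quantize_dynamic(float_gru, {nn.GRU},
                                             dtype=torch.qint8)

x = torch.rand(4, 2, 3)
y_ref, h_ref = float_gru(x)
y_dq, h_dq = gru_dq(x)
# Outputs should track the float reference up to quantization error.
print(torch.allclose(y_ref, y_dq, atol=0.05))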