Example #1
    def forward_impl(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
        batch_sizes: Optional[Tensor], max_batch_size: int,
        sorted_indices: Optional[Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_lstm(input, hx, _all_params, self.bias, self.num_layers,
                                          float(self.dropout), self.training, self.bidirectional,
                                          self.batch_first, dtype=self.dtype, use_dynamic=True)
        else:
            result = torch.quantized_lstm(input, batch_sizes, hx, _all_params, self.bias,
                                          self.num_layers, float(self.dropout), self.training,
                                          self.bidirectional, dtype=self.dtype, use_dynamic=True)
        output = result[0]
        hidden = result[1:]

        return output, hidden
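
For context, a forward_impl like the one above is what runs when a dynamically quantized LSTM is called. Below is a minimal end-to-end sketch, not taken from the example itself; it assumes a PyTorch build where torch.quantization.quantize_dynamic supports nn.LSTM, and the names float_lstm, qlstm, and x are purely illustrative.

    import torch
    import torch.nn as nn

    # Convert a float LSTM to a dynamically quantized one; calling it with
    # hx omitted exercises the zeros branch of forward_impl above.
    float_lstm = nn.LSTM(input_size=3, hidden_size=7, num_layers=2)
    qlstm = torch.quantization.quantize_dynamic(float_lstm, {nn.LSTM}, dtype=torch.qint8)

    x = torch.randn(4, 2, 3)   # (seq_len, batch, input_size)
    y, (h, c) = qlstm(x)       # hx is None -> zero hidden/cell states are built
    # y: (4, 2, 7); h and c: (num_layers * num_directions, batch, hidden) = (2, 2, 7)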
Example #2
    def forward_impl(self, input, hx, batch_sizes, max_batch_size,
                     sorted_indices):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size,
                                self.hidden_size,
                                dtype=input.dtype,
                                device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        assert batch_sizes is None
        result = torch.quantized_lstm(input,
                                      hx,
                                      self.all_weights,
                                      self.bias,
                                      self.num_layers,
                                      float(self.dropout),
                                      self.training,
                                      self.bidirectional,
                                      self.batch_first,
                                      dtype=self.dtype,
                                      use_dynamic=False)
        output = result[0]
        hidden = result[1:]

        return output, hidden
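
This variant asserts batch_sizes is None, so it only covers plain padded input; the dispatch on batch_sizes in Example #1 is what handles packed sequences. Below is a small sketch of supplying an explicit hidden state with the shape the zeros branch would otherwise build; the names are illustrative and it assumes the same quantize_dynamic conversion used above.

    import torch
    import torch.nn as nn

    qlstm = torch.quantization.quantize_dynamic(
        nn.LSTM(input_size=3, hidden_size=7, num_layers=2, bidirectional=True),
        {nn.LSTM}, dtype=torch.qint8)

    x = torch.randn(4, 2, 3)                    # (seq_len, batch, input_size)
    num_directions = 2                          # bidirectional=True
    h0 = torch.zeros(2 * num_directions, 2, 7)  # (layers * directions, batch, hidden)
    c0 = torch.zeros(2 * num_directions, 2, 7)
    y, (hn, cn) = qlstm(x, (h0, c0))            # hx given -> zeros branch is skipped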
Example #3
    def test_lstm_api(self):
        r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
        """
        # Check that module matches the numerics of the op and ensure that module can be
        # instantiated for all engines and dtypes

        for dtype in [torch.qint8, torch.float16]:
            if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack":
                # fp16 dynamic quant is not supported for qnnpack
                continue
            # Test default instantiation
            seq_len = 4
            batch = 2
            input_size = 3
            hidden_size = 7
            num_layers = 2
            bias = True
            bidirectional = False

            x = torch.randn(seq_len, batch, input_size)
            h = torch.randn(num_layers * (bidirectional + 1), batch,
                            hidden_size)
            c = torch.randn(num_layers * (bidirectional + 1), batch,
                            hidden_size)

            cell_dq = torch.nn.quantized.dynamic.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                bias=bias,
                batch_first=False,
                dropout=0.0,
                bidirectional=bidirectional,
                dtype=dtype)
            _all_params = [m.param for m in cell_dq._all_weight_values]
            result = torch.quantized_lstm(x, (h, c),
                                          _all_params,
                                          cell_dq.bias,
                                          cell_dq.num_layers,
                                          float(cell_dq.dropout),
                                          False,
                                          bidirectional,
                                          False,
                                          dtype=dtype,
                                          use_dynamic=True)

            y, (h, c) = cell_dq(x, (h, c))
            self.assertEqual(result[0], y)
            self.assertEqual(result[1], h)
            self.assertEqual(result[2], c)
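
The docstring above also mentions serialization. Here is a hedged sketch of what that check amounts to, assuming the dynamic module's state_dict round-trips its packed weights; the module and variable names below are illustrative and are not the test's helpers.

    import io
    import torch

    cell = torch.nn.quantized.dynamic.LSTM(input_size=3, hidden_size=7,
                                           num_layers=2, dtype=torch.qint8)
    clone = torch.nn.quantized.dynamic.LSTM(input_size=3, hidden_size=7,
                                            num_layers=2, dtype=torch.qint8)

    buf = io.BytesIO()
    torch.save(cell.state_dict(), buf)        # serialize the quantized weights
    buf.seek(0)
    clone.load_state_dict(torch.load(buf))    # restore into a fresh instance

    x = torch.randn(4, 2, 3)
    assert torch.equal(cell(x)[0], clone(x)[0])   # identical weights, identical output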
Example #4
    def test_lstm_api(self, dtype, bidirectional):
        r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
        """
        # Check that module matches the numerics of the op and ensure that module can be
        # instantiated for all engines and dtypes
        seq_len = 4
        batch = 2
        input_size = 3
        hidden_size = 7
        num_layers = 2
        bias = True
        weight_keys = []
        bias_keys = []
        num_directions = 2 if bidirectional else 1
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                key_name1 = 'weight_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                key_name2 = 'weight_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                weight_keys.append(key_name1)
                weight_keys.append(key_name2)
                key_name1 = 'bias_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                key_name2 = 'bias_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                bias_keys.append(key_name1)
                bias_keys.append(key_name2)

        if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"):
            # fp16 dynamic quant is not supported for qnnpack
            x = torch.randn(seq_len, batch, input_size)
            h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
            c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
            cell_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                      hidden_size=hidden_size,
                                                      num_layers=num_layers,
                                                      bias=bias,
                                                      batch_first=False,
                                                      dropout=0.0,
                                                      bidirectional=bidirectional,
                                                      dtype=dtype)
            ref_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                     hidden_size=hidden_size,
                                                     num_layers=num_layers,
                                                     bias=bias,
                                                     batch_first=False,
                                                     dropout=0.0,
                                                     bidirectional=bidirectional,
                                                     dtype=dtype)

            _all_params = [m.param for m in cell_dq._all_weight_values]
            result = torch.quantized_lstm(x, (h, c),
                                          _all_params,
                                          cell_dq.bias,
                                          cell_dq.num_layers,
                                          float(cell_dq.dropout),
                                          False,
                                          bidirectional,
                                          False,
                                          dtype=dtype,
                                          use_dynamic=True)


            y, (h, c) = cell_dq(x, (h, c))
            self.assertEqual(result[0], y)
            self.assertEqual(result[1], h)
            self.assertEqual(result[2], c)
            x = torch.randn(10, 20, 3)
            self.check_eager_serialization(cell_dq, ref_dq, [x])
            self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)
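
One more property these modules are written for: the Optional[...] annotations in Examples #1 and #2 keep the forward path TorchScript-compatible. A short, illustrative-only sketch of scripting a dynamic quantized LSTM and checking it against eager mode:

    import torch

    cell = torch.nn.quantized.dynamic.LSTM(input_size=3, hidden_size=7, num_layers=2)
    scripted = torch.jit.script(cell)         # relies on the annotations shown above

    x = torch.randn(4, 2, 3)
    y_eager, _ = cell(x)
    y_script, _ = scripted(x)
    assert torch.equal(y_eager, y_script)     # same packed params, same numerics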