Example 1
    def test_GRU_runtime_with_cond(self):
        # This test demonstrates the graph rewrite of the grappler plugin under
        # the condition that the function returns a different number of
        # internal states.
        layer = keras.layers.GRU(self.rnn_state_size, return_runtime=True)

        inputs = keras.layers.Input(shape=[self.timestep, self.input_shape],
                                    dtype=tf.float32)

        zeros = tf.zeros([self.batch, self.output_shape])
        dummy_runtime = gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_UNKNOWN)
        a = tf.constant(0)
        b = tf.constant(1)
        # Will always run the GRU layer.
        outputs, runtime = tf.cond(tf.less(a, b), lambda: layer(inputs),
                                   lambda: (zeros, dummy_runtime))

        # Expand the runtime so that it is a 1D tensor instead of a scalar.
        # A TF model does not work with a scalar model output, especially
        # during aggregation.
        runtime = keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(
            runtime)
        model = keras.models.Model(inputs=inputs, outputs=[outputs, runtime])
        self._test_runtime_with_model(model)
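
A note on the pattern in this test: `tf.cond` requires both branches to return the same nested structure with matching dtypes, which is why the never-taken `else` branch fabricates `zeros` and a `dummy_runtime` that mirror the GRU layer's `(outputs, runtime)` pair. A minimal, self-contained sketch of that constraint (illustrative values, not from the test):

import tensorflow as tf

pred = tf.constant(True)
x = tf.ones([2, 3])

# Both branches must return a (tensor, tensor) pair with matching dtypes;
# the second branch plays the role of the "dummy" fallback.
value, flag = tf.cond(
    pred,
    lambda: (x * 2.0, tf.constant(1)),
    lambda: (tf.zeros([2, 3]), tf.constant(0)))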
Example 2
def gpu_gru(
    inputs,
    init_h,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    sequence_lengths,
    return_sequences,
):
    """GRU with cuDNN implementation which is only available for GPU."""
    if mask is not None:
        sequence_lengths = gru_lstm_utils.calculate_sequence_by_mask(
            mask, time_major)

    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h, cuDNN expects one more num_layers dim before or after the
    # batch dim, for time-major or batch-major inputs respectively.
    init_h = tf.expand_dims(init_h, axis=seq_axis)

    weights = tf.split(kernel, 3, axis=1)
    weights += tf.split(recurrent_kernel, 3, axis=1)
    # Note that the bias was initialized as shape (2, 3 * units); flatten it
    # into (6 * units).
    bias = tf.split(backend.flatten(bias), 6)

    if tf.sysconfig.get_build_info()["is_cuda_build"]:
        # Note that the gate order for cuDNN differs from the canonical format:
        # the canonical format is [z, r, h], whereas cuDNN uses [r, z, h]. The
        # swap needs to be done for kernel, recurrent_kernel, input_bias, and
        # recurrent_bias.
        # z is the update gate weights.
        # r is the reset gate weights.
        # h is the candidate (new) gate weights.
        weights[0], weights[1] = weights[1], weights[0]
        weights[3], weights[4] = weights[4], weights[3]
        bias[0], bias[1] = bias[1], bias[0]
        bias[3], bias[4] = bias[4], bias[3]

    params = gru_lstm_utils.canonical_to_params(
        weights=weights,
        biases=bias,
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5, 4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        if go_backwards:
            # Reverse axis 0 since the input is already converted to time major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, _, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
        )

    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    h = tf.squeeze(h, axis=seq_axis)

    # In the case of variable-length input, the cuDNN kernel fills the output
    # with zeros, whereas the default Keras behavior is to carry over the
    # output from t-1, so that in the return_sequences=False case the user
    # gets the final effective output instead of just 0s at the last timestep.
    # To mimic the default Keras behavior, we copy the final h state into
    # last_output, since it is numerically the same as the output.
    if sequence_lengths is not None:
        last_output = h

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (
        last_output,
        outputs,
        h,
        gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_GPU),
    )
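
The gate swap in `gpu_gru` is easy to misread, so here is the same reordering on a standalone toy kernel. Assumed layout (matching the function above): canonical Keras order [z, r, h] along the last axis, with cuDNN wanting [r, z, h]; the variable names are illustrative:

import tensorflow as tf

units, input_dim = 2, 4
# Toy kernel in canonical [z, r, h] gate order, shape (input_dim, 3 * units).
kernel = tf.random.normal([input_dim, 3 * units])

gates = tf.split(kernel, 3, axis=1)      # [z, r, h] blocks
gates[0], gates[1] = gates[1], gates[0]  # reorder to cuDNN's [r, z, h]
cudnn_kernel = tf.concat(gates, axis=1)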
Example 3
def standard_gru(
    inputs,
    init_h,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    sequence_lengths,
    zero_output_for_mask,
    return_sequences,
):
    """GRU with standard kernel implementation.

    This implementation can be run on all types of hardware.

    This implementation lifts out all the layer weights and makes them
    function parameters. It has the same number of tensor input params as the
    cuDNN counterpart. The RNN step logic has been simplified, e.g. dropout
    and masking are removed since the cuDNN implementation does not support
    them.

    Args:
      inputs: Input tensor of GRU layer.
      init_h: Initial state tensor for the cell output.
      kernel: Weights for cell kernel.
      recurrent_kernel: Weights for cell recurrent kernel.
      bias: Weights for cell kernel bias and recurrent bias. The bias contains the
        combined input_bias and recurrent_bias.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked. An individual `True` entry indicates
        that the corresponding timestep should be utilized, while a `False` entry
        indicates that the corresponding timestep should be ignored.
      time_major: Boolean, whether the inputs are in the format of
        [time, batch, feature] or [batch, time, feature].
      go_backwards: Boolean (default False). If True, process the input sequence
        backwards and return the reversed sequence.
      sequence_lengths: The lengths of all sequences coming from a variable length
        input, such as ragged tensors. If the input has a fixed timestep size,
        this should be None.
      zero_output_for_mask: Boolean, whether to output zero for masked timestep.
      return_sequences: Boolean. If True, return the recurrent outputs for all
        timesteps in the sequence. If False, only return the output for the
        last timestep (which consumes less memory).

    Returns:
      last_output: output tensor for the last timestep, which has shape
        [batch, units].
      outputs:
        - If `return_sequences=True`: output tensor for all timesteps,
          which has shape [batch, time, units].
        - Else, a tensor equal to `last_output` with shape [batch, 1, units]
      state_0: the cell output, which has same shape as init_h.
      runtime: constant string tensor which indicates the real runtime
        hardware. This value is for testing purposes and should not be used
        by the user.
    """
    input_shape = backend.int_shape(inputs)
    timesteps = input_shape[0] if time_major else input_shape[1]

    input_bias, recurrent_bias = tf.unstack(bias)

    def step(cell_inputs, cell_states):
        """Step function that will be used by Keras RNN backend."""
        h_tm1 = cell_states[0]

        # inputs projected by all gate matrices at once
        matrix_x = backend.dot(cell_inputs, kernel)
        matrix_x = backend.bias_add(matrix_x, input_bias)

        x_z, x_r, x_h = tf.split(matrix_x, 3, axis=1)

        # hidden state projected by all gate matrices at once
        matrix_inner = backend.dot(h_tm1, recurrent_kernel)
        matrix_inner = backend.bias_add(matrix_inner, recurrent_bias)

        recurrent_z, recurrent_r, recurrent_h = tf.split(
            matrix_inner, 3, axis=1)
        z = tf.sigmoid(x_z + recurrent_z)
        r = tf.sigmoid(x_r + recurrent_r)
        hh = tf.tanh(x_h + r * recurrent_h)

        # previous and candidate state mixed by update gate
        h = z * h_tm1 + (1 - z) * hh
        return h, [h]

    last_output, outputs, new_states = backend.rnn(
        step,
        inputs,
        [init_h],
        constants=None,
        unroll=False,
        time_major=time_major,
        mask=mask,
        go_backwards=go_backwards,
        input_length=(sequence_lengths
                      if sequence_lengths is not None else timesteps),
        zero_output_for_mask=zero_output_for_mask,
        return_all_outputs=return_sequences,
    )
    return (
        last_output,
        outputs,
        new_states[0],
        gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_CPU),
    )
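
The `step` function above is the entire GRU recurrence. For checking the math, here is a NumPy transcription of one step under the same weight layout (`kernel` of shape (input_dim, 3 * units), gates ordered [z, r, h]); this is a sketch, not Keras API:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_tm1, kernel, recurrent_kernel, input_bias, recurrent_bias):
    # Project input and previous state through all three gates at once.
    matrix_x = x_t @ kernel + input_bias
    matrix_inner = h_tm1 @ recurrent_kernel + recurrent_bias
    x_z, x_r, x_h = np.split(matrix_x, 3, axis=1)
    r_z, r_r, r_h = np.split(matrix_inner, 3, axis=1)
    z = sigmoid(x_z + r_z)            # update gate
    r = sigmoid(x_r + r_r)            # reset gate
    hh = np.tanh(x_h + r * r_h)       # candidate state
    return z * h_tm1 + (1 - z) * hh   # previous and candidate state mixed by z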
Example 4
    def call(self, inputs, mask=None, training=None, initial_state=None):
        # The input should be dense, padded with zeros. If a ragged input is fed
        # into the layer, it is padded and the row lengths are used for masking.
        inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
        is_ragged_input = row_lengths is not None
        self._validate_args_if_ragged(is_ragged_input, mask)

        # GRU does not support constants. Ignore them during processing.
        inputs, initial_state, _ = self._process_inputs(
            inputs, initial_state, None)

        if isinstance(mask, list):
            mask = mask[0]

        input_shape = backend.int_shape(inputs)
        timesteps = input_shape[0] if self.time_major else input_shape[1]

        if not self._could_use_gpu_kernel:
            kwargs = {"training": training}
            self._maybe_reset_cell_dropout_mask(self.cell)

            def step(cell_inputs, cell_states):
                return self.cell(cell_inputs, cell_states, **kwargs)

            last_output, outputs, states = backend.rnn(
                step,
                inputs,
                initial_state,
                constants=None,
                go_backwards=self.go_backwards,
                mask=mask,
                unroll=self.unroll,
                input_length=(row_lengths
                              if row_lengths is not None else timesteps),
                time_major=self.time_major,
                zero_output_for_mask=self.zero_output_for_mask,
                return_all_outputs=self.return_sequences,
            )
            # This is a dummy tensor for testing purposes.
            runtime = gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_UNKNOWN)
        else:
            last_output, outputs, runtime, states = self._defun_gru_call(
                inputs, initial_state, training, mask, row_lengths)

        if self.stateful:
            updates = [
                tf.compat.v1.assign(self.states[0],
                                    tf.cast(states[0], self.states[0].dtype))
            ]
            self.add_update(updates)

        if self.return_sequences:
            output = backend.maybe_convert_to_ragged(
                is_ragged_input,
                outputs,
                row_lengths,
                go_backwards=self.go_backwards,
            )
        else:
            output = last_output

        if self.return_state:
            return [output] + list(states)
        elif self._return_runtime:
            return output, runtime
        else:
            return output
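
The tail of `call` decides the layer's public output. A short usage sketch of the resulting contracts (standard Keras GRU behavior, shapes shown for orientation):

import tensorflow as tf
from tensorflow import keras

x = tf.random.normal([8, 5, 3])  # (batch, time, features)

out = keras.layers.GRU(4)(x)                             # (8, 4): last step only
seq = keras.layers.GRU(4, return_sequences=True)(x)      # (8, 5, 4): all steps
out, state = keras.layers.GRU(4, return_state=True)(x)   # both (8, 4) for GRU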
Example 5
def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
             time_major, go_backwards, sequence_lengths):
  """LSTM with either cuDNN or ROCm implementation which is only available for GPU.

  Note that currently only right-padded data is supported; otherwise the
  result will be polluted by the unmasked data, which should be filtered out.

  Args:
    inputs: Input tensor of LSTM layer.
    init_h: Initial state tensor for the cell output.
    init_c: Initial state tensor for the cell hidden state.
    kernel: Weights for cell kernel.
    recurrent_kernel: Weights for cell recurrent kernel.
    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
      is used in this case.
    mask: Boolean tensor used to mask out the steps within the sequence. An
      individual `True` entry indicates that the corresponding timestep
      should be utilized, while a `False` entry indicates that the
      corresponding timestep should be ignored.
    time_major: Boolean, whether the inputs are in the format of [time, batch,
      feature] or [batch, time, feature].
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    sequence_lengths: The lengths of all sequences coming from a variable length
      input, such as ragged tensors. If the input has a fixed timestep size,
      this should be None.

  Returns:
    last_output: Output tensor for the last timestep, which has shape
      [batch, units].
    outputs: Output tensor for all timesteps, which has shape
      [batch, time, units].
    state_0: The cell output, which has same shape as init_h.
    state_1: The cell hidden state, which has same shape as init_c.
    runtime: Constant string tensor which indicates the real runtime
      hardware. This value is for testing purposes and should not be used by
      the user.
  """
  if mask is not None:
    sequence_lengths = gru_lstm_utils.calculate_sequence_by_mask(
        mask, time_major)

  if not time_major and sequence_lengths is None:
    inputs = tf.transpose(inputs, perm=(1, 0, 2))
    seq_axis, batch_axis = (0, 1)
  else:
    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
  # For init_h and init_c, cuDNN expects one more num_layers dim before or
  # after the batch dim, for time-major or batch-major inputs respectively.
  init_h = tf.expand_dims(init_h, axis=seq_axis)
  init_c = tf.expand_dims(init_c, axis=seq_axis)

  weights = tf.split(kernel, 4, axis=1)
  weights += tf.split(recurrent_kernel, 4, axis=1)
  # cuDNN has an extra set of biases for inputs; we disable them (setting
  # them to 0) so that mathematically it is the same as the canonical LSTM
  # implementation.
  full_bias = tf.concat((tf.zeros_like(bias), bias), 0)

  if tf.sysconfig.get_build_info()['is_rocm_build']:
    # ROCm MIOpen's weight sequence for LSTM differs from both the canonical
    # and the cuDNN format:
    # MIOpen: [i, f, o, c]; cuDNN/canonical: [i, f, c, o]
    # i is the input gate weights.
    # f is the forget gate weights.
    # o is the output gate weights.
    # c is the cell gate weights.
    weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
    # full_bias is a tensor of shape (8*n,)
    full_bias = tf.split(full_bias, 8, axis=0)
    full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]

  params = gru_lstm_utils.canonical_to_params(
      weights=weights,
      biases=tf.split(full_bias, 8),
      shape=tf.constant([-1]),
      transpose_weights=True)

  if sequence_lengths is not None:
    if go_backwards:
      # Three reversals are required. E.g.,
      # normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
      # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
      # output_from_cudnn = [6, 5, 4, 0, 0]
      # expected_output = [0, 0, 6, 5, 4]
      inputs = tf.reverse_sequence(
          inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
    outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
        input=inputs,
        input_h=init_h,
        input_c=init_c,
        params=params,
        is_training=True,
        rnn_mode='lstm',
        sequence_lengths=sequence_lengths,
        time_major=time_major)
    if go_backwards:
      outputs = tf.reverse_sequence(
          outputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
      outputs = tf.reverse(outputs, axis=[seq_axis])
  else:
    if go_backwards:
      # Reverse axis 0 since the input is already converted to time major.
      inputs = tf.reverse(inputs, axis=[0])
    outputs, h, c, _ = tf.raw_ops.CudnnRNN(
        input=inputs, input_h=init_h, input_c=init_c, params=params,
        is_training=True, rnn_mode='lstm')

  last_output = outputs[-1]
  if not time_major and sequence_lengths is None:
    outputs = tf.transpose(outputs, perm=[1, 0, 2])
  h = tf.squeeze(h, axis=seq_axis)
  c = tf.squeeze(c, axis=seq_axis)

  # In the case of variable-length input, the cuDNN kernel fills the output
  # with zeros, whereas the default Keras behavior is to carry over the
  # output from t-1, so that in the return_sequences=False case the user
  # gets the final effective output instead of just 0s at the last timestep.
  # To mimic the default Keras behavior, we copy the final h state into
  # last_output, since it is numerically the same as the output.
  if sequence_lengths is not None:
    last_output = h
  return last_output, outputs, h, c, gru_lstm_utils.runtime(
      gru_lstm_utils.RUNTIME_GPU)
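
One detail of `gpu_lstm` worth isolating is the bias layout: cuDNN carries separate input and recurrent bias sets, so the single canonical Keras bias is padded with zeros on the input side. A minimal sketch of just that construction:

import tensorflow as tf

units = 3
bias = tf.random.normal([4 * units])  # canonical Keras LSTM bias, (4 * units,)

# cuDNN expects two bias sets; zeroing the input set keeps the result
# mathematically identical to the single-bias canonical LSTM.
full_bias = tf.concat((tf.zeros_like(bias), bias), 0)  # shape (8 * units,)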
Example 6
def standard_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias,
                  mask, time_major, go_backwards, sequence_lengths,
                  zero_output_for_mask):
  """LSTM with standard kernel implementation.

  This implementation can be run on all types of hardware.

  This implementation lifts out all the layer weights and makes them function
  parameters. It has the same number of tensor input params as the cuDNN
  counterpart. The RNN step logic has been simplified, e.g. dropout and
  masking are removed since the cuDNN implementation does not support them.

  Note that the first half of the bias tensor should be ignored by this impl.
  The cuDNN impl needs an extra set of input gate biases. In order to make
  both functions take parameters of the same shape, that extra set of biases
  is also fed here.

  Args:
    inputs: input tensor of LSTM layer.
    init_h: initial state tensor for the cell output.
    init_c: initial state tensor for the cell hidden state.
    kernel: weights for cell kernel.
    recurrent_kernel: weights for cell recurrent kernel.
    bias: weights for cell kernel bias and recurrent bias. Only recurrent bias
      is used in this case.
    mask: Boolean tensor for mask out the steps within sequence.
      An individual `True` entry indicates that the corresponding timestep
      should be utilized, while a `False` entry indicates that the corresponding
      timestep should be ignored.
    time_major: boolean, whether the inputs are in the format of
      [time, batch, feature] or [batch, time, feature].
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    sequence_lengths: The lengths of all sequences coming from a variable length
      input, such as ragged tensors. If the input has a fixed timestep size,
      this should be None.
    zero_output_for_mask: Boolean, whether to output zero for masked timestep.

  Returns:
    last_output: output tensor for the last timestep, which has shape
      [batch, units].
    outputs: output tensor for all timesteps, which has shape
      [batch, time, units].
    state_0: the cell output, which has same shape as init_h.
    state_1: the cell hidden state, which has same shape as init_c.
    runtime: constant string tensor which indicates the real runtime
      hardware. This value is for testing purposes and should not be used by
      the user.
  """
  input_shape = backend.int_shape(inputs)
  timesteps = input_shape[0] if time_major else input_shape[1]

  def step(cell_inputs, cell_states):
    """Step function that will be used by Keras RNN backend."""
    h_tm1 = cell_states[0]  # previous memory state
    c_tm1 = cell_states[1]  # previous carry state

    z = backend.dot(cell_inputs, kernel)
    z += backend.dot(h_tm1, recurrent_kernel)
    z = backend.bias_add(z, bias)

    z0, z1, z2, z3 = tf.split(z, 4, axis=1)

    i = tf.sigmoid(z0)
    f = tf.sigmoid(z1)
    c = f * c_tm1 + i * tf.tanh(z2)
    o = tf.sigmoid(z3)

    h = o * tf.tanh(c)
    return h, [h, c]

  last_output, outputs, new_states = backend.rnn(
      step,
      inputs, [init_h, init_c],
      constants=None,
      unroll=False,
      time_major=time_major,
      mask=mask,
      go_backwards=go_backwards,
      input_length=(sequence_lengths
                    if sequence_lengths is not None else timesteps),
      zero_output_for_mask=zero_output_for_mask)
  return (last_output, outputs, new_states[0], new_states[1],
          gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_CPU))
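
As with the GRU example, the `step` function here is the whole LSTM recurrence. A NumPy transcription of one step, assuming the fused [i, f, c, o] gate layout used above (a sketch for checking the math, not Keras API):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_tm1, c_tm1, kernel, recurrent_kernel, bias):
    # One fused projection for all four gates, ordered [i, f, c, o].
    z = x_t @ kernel + h_tm1 @ recurrent_kernel + bias
    z0, z1, z2, z3 = np.split(z, 4, axis=1)
    i = sigmoid(z0)                  # input gate
    f = sigmoid(z1)                  # forget gate
    c = f * c_tm1 + i * np.tanh(z2)  # new carry state
    o = sigmoid(z3)                  # output gate
    h = o * np.tanh(c)               # new memory state
    return h, c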
Example 7
  def call(self, inputs, mask=None, training=None, initial_state=None):
    # The input should be dense, padded with zeros. If a ragged input is fed
    # into the layer, it is padded and the row lengths are used for masking.
    inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
    is_ragged_input = (row_lengths is not None)
    self._validate_args_if_ragged(is_ragged_input, mask)

    # LSTM does not support constants. Ignore them during processing.
    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)

    if isinstance(mask, list):
      mask = mask[0]

    input_shape = backend.int_shape(inputs)
    timesteps = input_shape[0] if self.time_major else input_shape[1]

    if not self._could_use_gpu_kernel:
      # Fall back to using the normal LSTM implementation.
      kwargs = {'training': training}
      self._maybe_reset_cell_dropout_mask(self.cell)

      def step(inputs, states):
        return self.cell(inputs, states, **kwargs)

      last_output, outputs, states = backend.rnn(
          step,
          inputs,
          initial_state,
          constants=None,
          go_backwards=self.go_backwards,
          mask=mask,
          unroll=self.unroll,
          input_length=row_lengths if row_lengths is not None else timesteps,
          time_major=self.time_major,
          zero_output_for_mask=self.zero_output_for_mask)
      runtime = gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_UNKNOWN)
    else:
      # Use the new defun approach for the backend implementation swap.
      # Note that different implementations need to have the same function
      # signature, e.g. the tensor parameters need to have the same shapes and
      # dtypes. Since cuDNN has an extra set of biases, those biases will be
      # passed to both the normal and the cuDNN implementations.
      self.reset_dropout_mask()
      dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
      if dropout_mask is not None:
        inputs = inputs * dropout_mask[0]
      if gru_lstm_utils.use_new_gru_lstm_impl():
        lstm_kwargs = {
            'inputs': inputs,
            'init_h': gru_lstm_utils.read_variable_value(initial_state[0]),
            'init_c': gru_lstm_utils.read_variable_value(initial_state[1]),
            'kernel': gru_lstm_utils.read_variable_value(self.cell.kernel),
            'recurrent_kernel':
                gru_lstm_utils.read_variable_value(self.cell.recurrent_kernel),
            'bias': gru_lstm_utils.read_variable_value(self.cell.bias),
            'mask': mask,
            'time_major': self.time_major,
            'go_backwards': self.go_backwards,
            'sequence_lengths': row_lengths,
            'zero_output_for_mask': self.zero_output_for_mask,
        }
        (last_output, outputs, new_h, new_c,
         runtime) = self._defun_wrapper.defun_layer(**lstm_kwargs)
      else:
        gpu_lstm_kwargs = {
            'inputs': inputs,
            'init_h': gru_lstm_utils.read_variable_value(initial_state[0]),
            'init_c': gru_lstm_utils.read_variable_value(initial_state[1]),
            'kernel': gru_lstm_utils.read_variable_value(self.cell.kernel),
            'recurrent_kernel':
                gru_lstm_utils.read_variable_value(self.cell.recurrent_kernel),
            'bias': gru_lstm_utils.read_variable_value(self.cell.bias),
            'mask': mask,
            'time_major': self.time_major,
            'go_backwards': self.go_backwards,
            'sequence_lengths': row_lengths,
        }
        normal_lstm_kwargs = gpu_lstm_kwargs.copy()
        normal_lstm_kwargs.update({
            'zero_output_for_mask': self.zero_output_for_mask,
        })

        if tf.executing_eagerly():
          device_type = gru_lstm_utils.get_context_device_type()
          can_use_gpu = (
              # Either user specified GPU or unspecified but GPU is available.
              (device_type == gru_lstm_utils.GPU_DEVICE_NAME or
               (device_type is None
                and tf.config.list_logical_devices('GPU'))) and
              (mask is None or
               gru_lstm_utils.is_cudnn_supported_inputs(mask, self.time_major)))
          # Under eager context, check the device placement and prefer the
          # GPU implementation when GPU is available.
          if can_use_gpu:
            last_output, outputs, new_h, new_c, runtime = gpu_lstm(
                **gpu_lstm_kwargs)
          else:
            last_output, outputs, new_h, new_c, runtime = standard_lstm(
                **normal_lstm_kwargs)
        else:
          (last_output, outputs, new_h, new_c,
           runtime) = lstm_with_backend_selection(**normal_lstm_kwargs)

      states = [new_h, new_c]

    if self.stateful:
      updates = [
          tf.compat.v1.assign(self_state, tf.cast(state, self_state.dtype))
          for self_state, state in zip(self.states, states)
      ]
      self.add_update(updates)

    if self.return_sequences:
      output = backend.maybe_convert_to_ragged(
          is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards)
    else:
      output = last_output

    if self.return_state:
      return [output] + list(states)
    elif self.return_runtime:
      return output, runtime
    else:
      return output
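
The eager branch of `call` reduces to: prefer the cuDNN kernel when a GPU is visible (or explicitly requested) and the mask, if any, is cuDNN-compatible. A simplified sketch of the device half of that check, using only public TF API (the `gru_lstm_utils` helpers are elided and replaced by a hypothetical flag):

import tensorflow as tf

def prefer_gpu_kernel(mask_is_cudnn_compatible=True):
    # True when at least one logical GPU is visible to TensorFlow and the
    # mask (if any) satisfies cuDNN's right-padding requirement.
    has_gpu = bool(tf.config.list_logical_devices('GPU'))
    return has_gpu and mask_is_cudnn_compatible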