Example #1
    def _call_cell(self,
                   inputs,
                   initial_cell_state=None,
                   initial_output=None,
                   dtype=None,
                   sequence_length=None):
        """Run this LSTM on inputs, starting from the given state.

        Args:
          inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
          initial_cell_state: initial value for cell state, shape `[batch_size,
            self._num_units]`
          initial_output: initial value of cell output, shape `[batch_size,
            self._num_units]`
          dtype: The data type for the initial state and expected output.
          sequence_length: Specifies the length of each sequence in inputs. An
            `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
            time_len)` or None.

        Returns:
          A pair containing:

          - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
                             output_size]`
          - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
                        output_size]`
        """

        inputs_shape = inputs.get_shape().with_rank(3)
        time_len = inputs_shape[0].value
        if time_len is None:
            time_len = array_ops.shape(inputs)[0]

        if self._use_peephole:
            wci = self._w_i_diag
            wco = self._w_o_diag
            wcf = self._w_f_diag
        else:
            wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype)

        if sequence_length is None:
            max_seq_len = math_ops.to_int64(time_len)
        else:
            max_seq_len = math_ops.to_int64(
                math_ops.reduce_max(sequence_length))

        _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm(
            seq_len_max=max_seq_len,
            x=inputs,
            cs_prev=initial_cell_state,
            h_prev=initial_output,
            w=self._kernel,
            wci=wci,
            wcf=wcf,
            wco=wco,
            b=self._bias,
            forget_bias=self._forget_bias,
            cell_clip=self._cell_clip,
            use_peephole=self._use_peephole)
        return cs, h
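A minimal usage sketch for the method above, assuming the enclosing class is the TF 1.x fused cell `tf.contrib.rnn.LSTMBlockFusedCell` (graph mode, time-major inputs); the public `__call__` builds the kernel/bias variables and then delegates to `_call_cell`:

import tensorflow as tf

time_len, batch_size, input_size, num_units = 20, 8, 32, 64
inputs = tf.placeholder(tf.float32, [time_len, batch_size, input_size])
seq_len = tf.placeholder(tf.int32, [batch_size])

# Fused block-LSTM cell; outputs are time-major, the final state is an
# LSTMStateTuple(c, h).
cell = tf.contrib.rnn.LSTMBlockFusedCell(num_units)
outputs, (final_c, final_h) = cell(inputs, dtype=tf.float32,
                                   sequence_length=seq_len)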
Example #2
  def _call_cell(self,
                 inputs,
                 initial_cell_state=None,
                 initial_output=None,
                 dtype=None,
                 sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
        time_len)` or None.

    Returns:
      A pair containing:

      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
                         output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
                    output_size]`
    """

    inputs_shape = inputs.get_shape().with_rank(3)
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    if self._use_peephole:
      wci = self._w_i_diag
      wco = self._w_o_diag
      wcf = self._w_f_diag
    else:
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype)

    if sequence_length is None:
      max_seq_len = math_ops.cast(time_len, dtypes.int64)
    else:
      max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length),
                                  dtypes.int64)

    _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=self._kernel,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h
Example #3
  def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
                 sequence_length):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
        time_len)` or None.

    Returns:
      A pair containing:

      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
                         output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
                    output_size]`
    """

    inputs_shape = inputs.get_shape().with_rank(3)
    time_len = inputs_shape[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]
    input_size = inputs_shape[2].value
    w = vs.get_variable(
        "kernel",
        [input_size + self._num_units, self._num_units * 4], dtype=dtype)
    b = vs.get_variable(
        "bias", [w.get_shape().with_rank(2)[1]],
        initializer=init_ops.constant_initializer(0.0),
        dtype=dtype)
    if self._use_peephole:
      wci = vs.get_variable("w_i_diag", [self._num_units], dtype=dtype)
      wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype)
      wcf = vs.get_variable("w_f_diag", [self._num_units], dtype=dtype)
    else:
      wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype)

    if sequence_length is None:
      max_seq_len = math_ops.to_int64(time_len)
    else:
      max_seq_len = math_ops.to_int64(math_ops.reduce_max(sequence_length))

    _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=w,
        wci=wci,
        wco=wco,
        wcf=wcf,
        b=b,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h
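For orientation, here is a NumPy sketch of the per-timestep recurrence that the `block_lstm` kernel is documented to compute, using the `kernel`/`bias` shapes created above. The `[i, ci, f, o]` slicing order is an assumption for illustration; the real op defines its own gate layout.

import numpy as np

def _sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def block_lstm_step(x, cs_prev, h_prev, w, b, wci, wcf, wco,
                    forget_bias=1.0, cell_clip=None):
    """One illustrative timestep of the fused block-LSTM recurrence."""
    xh = np.concatenate([x, h_prev], axis=1)       # [batch, input_size + num_units]
    gates = xh.dot(w) + b                          # [batch, 4 * num_units]
    i, ci, f, o = np.split(gates, 4, axis=1)       # gate order assumed (see above)
    i = _sigmoid(i + cs_prev * wci)                # input gate, optional peephole
    f = _sigmoid(f + forget_bias + cs_prev * wcf)  # forget gate with forget_bias
    ci = np.tanh(ci)                               # candidate cell input
    cs = ci * i + cs_prev * f                      # new cell state
    if cell_clip is not None and cell_clip >= 0:
        cs = np.clip(cs, -cell_clip, cell_clip)    # optional cell clipping
    o = _sigmoid(o + cs * wco)                     # output gate, optional peephole
    co = np.tanh(cs)
    h = co * o                                     # cell output
    return cs, h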
Example #4
def _block_lstm(seq_len_max,
                x,
                w,
                b,
                cs_prev=None,
                h_prev=None,
                wci=None,
                wcf=None,
                wco=None,
                forget_bias=None,
                cell_clip=None,
                use_peephole=None,
                name=None):
  r"""TODO(williamchan): add doc.

  Args:
    seq_len_max: A `Tensor` of type `int64`.
    x: A list of at least 1 `Tensor` objects of the same type in: `float32`.
    w: A `Tensor`. Must have the same type as `x`.
    b: A `Tensor`. Must have the same type as `x`.
    cs_prev: A `Tensor`. Must have the same type as `x`.
    h_prev: A `Tensor`. Must have the same type as `x`.
    wci: A `Tensor`. Must have the same type as `x`.
    wcf: A `Tensor`. Must have the same type as `x`.
    wco: A `Tensor`. Must have the same type as `x`.
    forget_bias: An optional `float`. Defaults to `1`.
    cell_clip: An optional `float`. Defaults to `3`.
    use_peephole: An optional `bool`. Defaults to `False`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
    i: A list with the same number of `Tensor` objects as `x` of `Tensor`
    objects of the same type as x.
    cs: A list with the same number of `Tensor` objects as `x` of `Tensor`
    objects of the same type as x.
    f: A list with the same number of `Tensor` objects as `x` of `Tensor`
    objects of the same type as x.
    o: A list with the same number of `Tensor` objects as `x` of `Tensor`
    objects of the same type as x.
    ci: A list with the same number of `Tensor` objects as `x` of `Tensor`
    objects of the same type as x.
    co: A list with the same number of `Tensor` objects as `x` of `Tensor`
    objects of the same type as x.
    h: A list with the same number of `Tensor` objects as `x` of `Tensor`
    objects of the same type as x.

  Raises:
    ValueError: If `b` does not have a valid shape.
  """
  batch_size = x[0].get_shape().with_rank(2)[0].value
  cell_size4 = b.get_shape().with_rank(1)[0].value
  if cell_size4 is None:
    raise ValueError("`b` shape must not be None.")
  cell_size = cell_size4 // 4  # integer division: the bias packs 4 gates per unit
  zero_state = None
  if cs_prev is None or h_prev is None:
    zero_state = array_ops.constant(
        0, dtype=dtypes.float32, shape=[batch_size, cell_size])
  if cs_prev is None:
    cs_prev = zero_state
  if h_prev is None:
    h_prev = zero_state
  if wci is None:
    wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
    wco = wci
    wcf = wci

  # pylint: disable=protected-access
  i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm(
      seq_len_max=seq_len_max,
      x=array_ops.stack(x),
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wco=wco,
      wcf=wcf,
      b=b,
      forget_bias=forget_bias,
      cell_clip=cell_clip,
      name=name,
      use_peephole=use_peephole)

  return array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(
      f), array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(
          co), array_ops.unstack(h)
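A hypothetical driver for the module-private `_block_lstm` helper, with shapes taken from its docstring: `x` is a Python list of per-timestep `[batch_size, input_size]` tensors, `w` is the fused `[input_size + cell_size, 4 * cell_size]` kernel, and `b` the matching `[4 * cell_size]` bias.

import tensorflow as tf

time_len, batch_size, input_size, cell_size = 5, 4, 8, 16
x = [tf.placeholder(tf.float32, [batch_size, input_size])
     for _ in range(time_len)]
w = tf.get_variable("kernel", [input_size + cell_size, 4 * cell_size])
b = tf.get_variable("bias", [4 * cell_size],
                    initializer=tf.zeros_initializer())
seq_len_max = tf.constant(time_len, dtype=tf.int64)

# cs_prev / h_prev / peephole weights default to zeros; each returned value
# is a list of time_len tensors shaped [batch_size, cell_size].
i, cs, f, o, ci, co, h = _block_lstm(seq_len_max, x, w, b)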
Example #5
def _block_lstm(seq_len_max,
                x,
                w,
                b,
                cs_prev=None,
                h_prev=None,
                wci=None,
                wcf=None,
                wco=None,
                forget_bias=None,
                cell_clip=None,
                use_peephole=None,
                name=None):
    r"""TODO(williamchan): add doc.

    Args:
      seq_len_max: A `Tensor` of type `int64`.
      x: A list of at least 1 `Tensor` objects of the same type in: `float32`.
      w: A `Tensor`. Must have the same type as `x`.
      b: A `Tensor`. Must have the same type as `x`.
      cs_prev: A `Tensor`. Must have the same type as `x`.
      h_prev: A `Tensor`. Must have the same type as `x`.
      wci: A `Tensor`. Must have the same type as `x`.
      wcf: A `Tensor`. Must have the same type as `x`.
      wco: A `Tensor`. Must have the same type as `x`.
      forget_bias: An optional `float`. Defaults to `1`.
      cell_clip: An optional `float`. Defaults to `-1` (no clipping).
      use_peephole: An optional `bool`. Defaults to `False`.
      name: A name for the operation (optional).

    Returns:
      A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
      i: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
      cs: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
      f: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
      o: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
      ci: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
      co: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
      h: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.

    Raises:
      ValueError: If `b` does not have a valid shape.
    """
    batch_size = x[0].get_shape().with_rank(2)[0].value
    cell_size4 = b.get_shape().with_rank(1)[0].value
    if cell_size4 is None:
        raise ValueError("`b` shape must not be None.")
    cell_size = cell_size4 // 4  # integer division: the bias packs 4 gates per unit
    zero_state = None
    if cs_prev is None or h_prev is None:
        zero_state = array_ops.constant(0,
                                        dtype=dtypes.float32,
                                        shape=[batch_size, cell_size])
    if cs_prev is None:
        cs_prev = zero_state
    if h_prev is None:
        h_prev = zero_state
    if wci is None:
        wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
        wcf = wci
        wco = wci

    # pylint: disable=protected-access
    i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm(
        seq_len_max=seq_len_max,
        x=array_ops.stack(x),
        cs_prev=cs_prev,
        h_prev=h_prev,
        w=w,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=b,
        forget_bias=forget_bias,
        cell_clip=cell_clip if cell_clip is not None else -1,
        name=name,
        use_peephole=use_peephole)

    return array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(
        f), array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(
            co), array_ops.unstack(h)