Example #1
  def __init__(self, cell, helper, initial_state, output_layer=None):
    """Initialize BasicDecoder.

    Args:
      cell: An `RNNCell` instance.
      helper: A `Helper` instance.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        The initial state of the RNNCell.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
        to storing the result or sampling.

    Raises:
      TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if not isinstance(helper, helper_py.Helper):
      raise TypeError("helper must be a Helper, received: %s" % type(helper))
    if (output_layer is not None
        and not isinstance(output_layer, layers_base.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._helper = helper
    self._initial_state = initial_state
    self._output_layer = output_layer
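Nearly every snippet in this collection validates its cell through `rnn_cell_impl.assert_like_rnncell(cell_name, cell)`, where the first argument is only a display name used in error messages. As a rough mental model, the check duck-types the cell rather than testing `isinstance`; the following is a minimal sketch of that idea, not the actual TensorFlow implementation:

def assert_like_rnncell_sketch(cell_name, cell):
    """Approximate duck-typing check: an object only needs to *look*
    like an RNNCell, not inherit from it."""
    looks_like_rnncell = (
        hasattr(cell, "output_size")
        and hasattr(cell, "state_size")
        and (hasattr(cell, "zero_state") or hasattr(cell, "get_initial_state"))
        and callable(cell))
    if not looks_like_rnncell:
        raise TypeError(
            "The argument %r (%s) is not an RNNCell: it needs output_size, "
            "state_size, zero_state/get_initial_state, and __call__."
            % (cell_name, cell))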
Example #2
    def __init__(self, cell, sampler, output_layer=None, **kwargs):
        """Initialize BasicDecoder.

    Args:
      cell: An `RNNCell` instance.
      sampler: A `Sampler` instance.
      output_layer: (Optional) An instance of `tf.compat.v1.layers.Layer`, i.e.,
        `tf.compat.v1.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      **kwargs: Other keyword arguments for layer creation.

    Raises:
      TypeError: if `cell`, `sampler` or `output_layer` have an incorrect type.
    """
        rnn_cell_impl.assert_like_rnncell("cell", cell)
        if not isinstance(sampler, sampler_py.Sampler):
            raise TypeError("sampler must be a Sampler, received: %s" %
                            (sampler, ))
        if (output_layer is not None
                and not isinstance(output_layer, layers.Layer)):
            raise TypeError("output_layer must be a Layer, received: %s" %
                            (output_layer, ))
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer
        super(BasicDecoderV2, self).__init__(**kwargs)
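The sampler-based signature above matches the Keras-era decoder from TensorFlow Addons. A minimal construction sketch (assuming `tfa.seq2seq` is the home of this API; the vocabulary and unit sizes are illustrative):

import tensorflow as tf
import tensorflow_addons as tfa  # assumed source of the V2 decoder/sampler API

vocab_size, units = 1000, 128                # illustrative sizes
cell = tf.keras.layers.LSTMCell(units)
sampler = tfa.seq2seq.TrainingSampler()      # feeds ground-truth inputs each step
output_layer = tf.keras.layers.Dense(vocab_size)
decoder = tfa.seq2seq.BasicDecoder(cell, sampler, output_layer=output_layer)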
Example #3
  def __init__(self,
               cell,
               embedding_classes,
               embedding_size,
               initializer=None,
               reuse=None):
    """Create a cell with an added input embedding.

    Args:
      cell: an RNNCell, an embedding will be put before its inputs.
      embedding_classes: integer, how many symbols will be embedded.
      embedding_size: integer, the size of the vectors we embed into.
      initializer: an initializer to use when creating the embedding;
        if None, the initializer from variable scope or a default one is used.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if embedding_classes is not positive.
    """
    super(EmbeddingWrapper, self).__init__(_reuse=reuse)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if embedding_classes <= 0 or embedding_size <= 0:
      raise ValueError("Both embedding_classes and embedding_size must be > 0: "
                       "%d, %d." % (embedding_classes, embedding_size))
    self._cell = cell
    self._embedding_classes = embedding_classes
    self._embedding_size = embedding_size
    self._initializer = initializer
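A minimal TF 1.x usage sketch for the wrapper above (assuming the `tf.contrib.rnn.EmbeddingWrapper` variant; sizes are illustrative):

import tensorflow as tf  # TF 1.x

base_cell = tf.nn.rnn_cell.GRUCell(64)
# The wrapped cell now takes integer symbol ids as input; each id is
# embedded into a 32-dimensional vector before reaching the GRU.
cell = tf.contrib.rnn.EmbeddingWrapper(
    base_cell, embedding_classes=10000, embedding_size=32)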
Example #4
  def __init__(self,
               cell,
               num_proj,
               activation=None,
               input_size=None,
               reuse=None):
    """Create a cell with input projection.

    Args:
      cell: an RNNCell, a projection of inputs is added before it.
      num_proj: Python integer.  The dimension to project to.
      activation: (optional) an optional activation function.
      input_size: Deprecated and unused.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
    """
    super(InputProjectionWrapper, self).__init__(_reuse=reuse)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    self._cell = cell
    self._num_proj = num_proj
    self._activation = activation
    self._linear = None
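And the input-projection counterpart, sketched the same way (assuming the `tf.contrib.rnn.InputProjectionWrapper` variant):

import tensorflow as tf  # TF 1.x

base_cell = tf.nn.rnn_cell.LSTMCell(64)
# Each input is linearly projected to num_proj dimensions (here with a
# ReLU) before it reaches the wrapped cell.
cell = tf.contrib.rnn.InputProjectionWrapper(
    base_cell, num_proj=64, activation=tf.nn.relu)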
Example #5
    def __init__(self, cell, helper, initial_state, output_layer=None):
        rnn_cell_impl.assert_like_rnncell("cell", cell)

        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        self._output_layer = output_layer
Example #6
    def __init__(self, cell, helper, initial_state, output_layer=None):
        '''
        Custom decoder; see: https://blog.csdn.net/thriving_fcl/article/details/74165062
        :param cell: An `RNNCell` instance
        :param helper: A `Helper` instance
        :param initial_state: The initial state of the RNNCell (the encoder output)
        :param output_layer: A `tf.layers.Layer`, e.g. `tf.layers.Dense`
        '''
        if parse_version(tf.__version__) >= parse_version('1.10'):
            rnn_cell_impl.assert_like_rnncell("cell", cell)
        else:
            if not rnn_cell_impl._like_rnncell(cell):
                raise TypeError('cell must be an RNNCell, received: %s' %
                                type(cell))

        if not isinstance(helper, helper_py.Helper):
            raise TypeError('helper must be a Helper, received: %s' %
                            type(helper))
        if output_layer is not None and not isinstance(output_layer,
                                                       layers_base.Layer):
            raise TypeError('output_layer must be a Layer, received: %s' %
                            type(output_layer))
        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        self._output_layer = output_layer
Example #7
 def __init__(self,
              cell,
              helper,
              initial_state,
              ini_state,
              ini_att,
              output_layer=None):
     """Initialize BasicDecoder.
 Args:
   cell: An `RNNCell` instance.
   helper: A `Helper` instance.
   initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
     The initial state of the RNNCell. it is the initial hidden state obtained from the last time step of encoder
   ini_state: it is just a zero vector with the size of hidden state in order to start saving the hidden states at each time step
   output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
     `tf.layers.Dense`. Optional layer to apply to the RNN output prior
     to storing the result or sampling.
 Raises:
   TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
 """
     rnn_cell_impl.assert_like_rnncell("cell", cell)
     if not isinstance(helper, helper_py.Helper):
         raise TypeError("helper must be a Helper, received: %s" %
                         type(helper))
     if (output_layer is not None
             and not isinstance(output_layer, layers_base.Layer)):
         raise TypeError("output_layer must be a Layer, received: %s" %
                         type(output_layer))
     self._cell = cell
     self._helper = helper
     self._initial_state = initial_state
     self._output_layer = output_layer
     self._ini_state = ini_state
     self._ini_att = ini_att
Example #8
    def __init__(self, cell, helper, initial_state, output_layer=None):
        """Initialize CustomDecoder.
		Args:
			cell: An `RNNCell` instance.
			helper: A `Helper` instance.
			initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
				The initial state of the RNNCell.
			output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
				`tf.layers.Dense`. Optional layer to apply to the RNN output prior
				to storing the result or sampling.
		Raises:
			TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
		"""
        rnn_cell_impl.assert_like_rnncell("cell", cell)
        if not isinstance(helper, helper_py.Helper):
            raise TypeError("helper must be a Helper, received: %s" %
                            type(helper))
        if (output_layer is not None
                and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError("output_layer must be a Layer, received: %s" %
                            type(output_layer))
        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        self._output_layer = output_layer
Example #9
    def __init__(self,
                 cell,
                 embedding_classes,
                 embedding_size,
                 initializer=None,
                 reuse=None):
        """Create a cell with an added input embedding.

    Args:
      cell: an RNNCell, an embedding will be put before its inputs.
      embedding_classes: integer, how many symbols will be embedded.
      embedding_size: integer, the size of the vectors we embed into.
      initializer: an initializer to use when creating the embedding;
        if None, the initializer from variable scope or a default one is used.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if embedding_classes is not positive.
    """
        super(EmbeddingWrapper, self).__init__(_reuse=reuse)
        rnn_cell_impl.assert_like_rnncell("cell", cell)
        if embedding_classes <= 0 or embedding_size <= 0:
            raise ValueError(
                "Both embedding_classes and embedding_size must be > 0: "
                "%d, %d." % (embedding_classes, embedding_size))
        self._cell = cell
        self._embedding_classes = embedding_classes
        self._embedding_size = embedding_size
        self._initializer = initializer
Example #10
  def __init__(self, cell, sampler, output_layer=None, **kwargs):
    """Initialize BasicDecoder.

    Args:
      cell: An `RNNCell` instance.
      sampler: A `Sampler` instance.
      output_layer: (Optional) An instance of `tf.compat.v1.layers.Layer`, i.e.,
        `tf.compat.v1.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      **kwargs: Other keyword arguments for layer creation.

    Raises:
      TypeError: if `cell`, `sampler` or `output_layer` have an incorrect type.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if not isinstance(sampler, sampler_py.Sampler):
      raise TypeError("sampler must be a Sampler, received: %s" % (sampler,))
    if (output_layer is not None and
        not isinstance(output_layer, layers.Layer)):
      raise TypeError("output_layer must be a Layer, received: %s" %
                      (output_layer,))
    self.cell = cell
    self.sampler = sampler
    self.output_layer = output_layer
    super(BasicDecoderV2, self).__init__(**kwargs)
Example #11
    def __init__(self,
                 cell,
                 num_proj,
                 activation=None,
                 input_size=None,
                 reuse=None):
        """Create a cell with input projection.

    Args:
      cell: an RNNCell, a projection of inputs is added before it.
      num_proj: Python integer.  The dimension to project to.
      activation: (optional) an optional activation function.
      input_size: Deprecated and unused.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
    """
        super(InputProjectionWrapper, self).__init__(_reuse=reuse)
        if input_size is not None:
            logging.warn("%s: The input_size parameter is deprecated.", self)
        rnn_cell_impl.assert_like_rnncell("cell", cell)
        self._cell = cell
        self._num_proj = num_proj
        self._activation = activation
        self._linear = None
Example #12
 def __init__(self,
              cell,
              attn_length,
              attn_size=None,
              attn_vec_size=None,
              input_size=None,
              state_is_tuple=True,
              reuse=None):
     """Create a cell with attention.
     Args:
         cell: an RNNCell, an attention is added to it.
         attn_length: integer, the size of an attention window.
         attn_size: integer, the size of an attention vector. Equal to
             cell.output_size by default.
         attn_vec_size: integer, the number of convolutional features
             calculated on attention state and a size of the hidden layer
              built from base cell state. Equal to attn_size by default.
         input_size: integer, the size of a hidden linear layer, built from
             inputs and attention. Derived from the input tensor by default.
          state_is_tuple: If True (the default), accepted and returned states
              are n-tuples, where `n = len(cells)`. If False, the states are
              all concatenated along the column axis.
         reuse: (optional) Python boolean describing whether to reuse
             variables in an existing scope. If not `True`, and the existing
             scope already has the given variables, an error is raised.
     Raises:
         TypeError: if cell is not an RNNCell.
         ValueError: if cell returns a state tuple but the flag
             `state_is_tuple` is `False` or if attn_length is zero or less.
     """
     super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse)
      rnn_cell_impl.assert_like_rnncell("cell", cell)
     if nest.is_sequence(cell.state_size) and not state_is_tuple:
         raise ValueError("Cell returns tuple of states, but the flag "
                          "state_is_tuple is not set. State size is: %s" %
                          str(cell.state_size))
     if attn_length <= 0:
         raise ValueError(
             "attn_length should be greater than zero, got %s" %
             str(attn_length))
     if not state_is_tuple:
          logging.warn(
              "%s: Using a concatenated state is slower and will soon be "
              "deprecated.  Use state_is_tuple=True.", self)
     if attn_size is None:
         attn_size = cell.output_size
     if attn_vec_size is None:
         attn_vec_size = attn_size
     self._state_is_tuple = state_is_tuple
     self._cell = cell
     self._attn_vec_size = attn_vec_size
     self._input_size = input_size
     self._attn_size = attn_size
     self._attn_length = attn_length
     self._reuse = reuse
     self._attention_mech = TemporalPatternAttentionMechanism()
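This wrapper is project-specific rather than part of TensorFlow, so the following construction is a hypothetical sketch (sizes assumed):

base_cell = tf.nn.rnn_cell.LSTMCell(128)
# Attend over a window of the last 16 steps; keep the default tuple states.
cell = TemporalPatternAttentionCellWrapper(base_cell, attn_length=16,
                                           state_is_tuple=True)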
Example #13
  def __init__(
      self,
      decoder_cell,
      helper,
      initial_decoder_state,
      attention_type,
      spec_layer,
      stop_token_layer,
      prenet=None,
      dtype=dtypes.float32,
      train=True
  ):
    """Initialize TacotronDecoder.

    Args:
      decoder_cell: An `RNNCell` instance.
      helper: A `Helper` instance.
      initial_decoder_state: A (possibly nested tuple of...) tensors and
        TensorArrays. The initial state of the RNNCell.
      attention_type: The type of attention used.
      stop_token_layer: An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. Stop token layer to apply to the RNN output to
        predict when to stop the decoder.
      spec_layer: An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. Output layer to apply to the RNN output to map
        the result to a spectrogram.
      prenet: The prenet to apply to the inputs.
      dtype: Data type of the decoder tensors. Defaults to `dtypes.float32`.
      train: Python boolean. If `True`, `spec_layer` and `stop_token_layer`
        are disabled inside the decoder (set to `None` in the constructor).

    Raises:
      TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
    """
    rnn_cell_impl.assert_like_rnncell("cell", decoder_cell)
    if not isinstance(helper, helper_py.Helper):
      raise TypeError("helper must be a Helper, received: %s" % type(helper))
    if (
        spec_layer is not None and
        not isinstance(spec_layer, layers_base.Layer)
    ):
      raise TypeError(
          "spec_layer must be a Layer, received: %s" % type(spec_layer)
      )
    self._decoder_cell = decoder_cell
    self._helper = helper
    self._decoder_initial_state = initial_decoder_state
    self._spec_layer = spec_layer
    self._stop_token_layer = stop_token_layer
    self._attention_type = attention_type
    self._dtype = dtype
    self._prenet = prenet

    if train:
      self._spec_layer = None
      self._stop_token_layer = None
Example #14
	def __init__(self, cell, helper, initial_state, output_layer=None):
		rnn_cell_impl.assert_like_rnncell("cell", cell)
		if not isinstance(helper, helper_py.Helper):
			raise TypeError("helper must be a Helper, received: %s" % type(helper))
		if (output_layer is not None
				and not isinstance(output_layer, layers_base.Layer)):
			raise TypeError(
					"output_layer must be a Layer, received: %s" % type(output_layer))
		self._cell = cell
		self._helper = helper
		self._initial_state = initial_state
		self._output_layer = output_layer
Example #15
    def __init__(self,
                 cell,
                 beam_width,
                 output_layer=None,
                 length_penalty_weight=0.0,
                 coverage_penalty_weight=0.0,
                 reorder_tensor_arrays=True,
                 **kwargs):
        """Initialize the BeamSearchDecoderMixin.

        Args:
          cell: An `RNNCell` instance.
          beam_width:  Python integer, the number of beams.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`,
            i.e., `tf.keras.layers.Dense`.  Optional layer to apply to the RNN
            output prior to storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled with
             0.0.
          coverage_penalty_weight: Float weight to penalize the coverage of
            source sentence. Disabled with 0.0.
          reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
            cell state will be reordered according to the beam search path. If
            the `TensorArray` can be reordered, the stacked form will be
            returned. Otherwise, the `TensorArray` will be returned as is. Set
            this flag to `False` if the cell state contains `TensorArray`s that
            are not amenable to reordering.
          **kwargs: Dict, other keyword arguments for parent class.

        Raises:
          TypeError: if `cell` is not an instance of `RNNCell`,
            or `output_layer` is not an instance of `tf.keras.layers.Layer`.
        """
        rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
        if (output_layer is not None
                and not isinstance(output_layer, tf.keras.layers.Layer)):
            raise TypeError("output_layer must be a Layer, received: %s" %
                            type(output_layer))
        self._cell = cell
        self._output_layer = output_layer
        self._reorder_tensor_arrays = reorder_tensor_arrays

        self._start_tokens = None
        self._end_token = None
        self._batch_size = None
        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._coverage_penalty_weight = coverage_penalty_weight
        super(BeamSearchDecoderMixin, self).__init__(**kwargs)
Example #16
 def __init__(self, cell, output_size, W, activation=None, reuse=None):
     """Create a cell with output projection.
     Args:
       cell: an RNNCell, a projection to output_size is added to it.
       output_size: integer, the size of the output after projection.
       W: the weight matrix used for the projection (supplied by the caller).
       activation: (optional) an optional activation function.
       reuse: (optional) Python boolean describing whether to reuse variables
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     Raises:
       TypeError: if cell is not an RNNCell.
       ValueError: if output_size is not positive.
     """
     super(MyOutputProjectionWrapper, self).__init__(_reuse=reuse)
     rnn_cell_impl.assert_like_rnncell("cell", cell)
     if output_size < 1:
         raise ValueError("Parameter output_size must be > 0: %d." %
                          output_size)
     self._cell = cell
     self._output_size = output_size
     self._activation = activation
     self._W = W
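The `W` argument is undocumented above; judging from the constructor it is the caller-supplied projection matrix. A hypothetical construction (shapes assumed):

cell = tf.nn.rnn_cell.GRUCell(128)
# Assumed: W maps the cell output (128 units) to the projected size (64).
W = tf.get_variable("proj_W", shape=[128, 64])
projected = MyOutputProjectionWrapper(cell, output_size=64, W=W)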
Example #17
    def __init__(self,
                 cell,
                 helper,
                 initial_state,
                 encoder_outputs,
                 turn_points,
                 output_layer=None,
                 aux_hidden_state=None):
        """Initialize BasicDecoder.

    Args:
      cell: An `RNNCell` instance.
      helper: A `Helper` instance.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        The initial state of the RNNCell.
      encoder_outputs: The outputs of the encoder.
      turn_points: Points where the conversation switches between parties.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
        to storing the result or sampling.
      aux_hidden_state: Hidden embeddings of the context information.

    Raises:
      TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
    """
        rnn_cell_impl.assert_like_rnncell("", cell)
        if not isinstance(helper, helper_py.Helper):
            raise TypeError("helper must be a Helper, received: %s" %
                            type(helper))
        if (output_layer is not None
                and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError("output_layer must be a Layer, received: %s" %
                            type(output_layer))
        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        self._output_layer = output_layer
        self.encoder_outputs = encoder_outputs
        self.turn_points = turn_points
        self._aux_hidden_state = aux_hidden_state
Example #18
  def __init__(self, cell, output_size, activation=None, reuse=None):
    """Create a cell with output projection.

    Args:
      cell: an RNNCell, a projection to output_size is added to it.
      output_size: integer, the size of the output after projection.
      activation: (optional) an optional activation function.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if output_size is not positive.
    """
    super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if output_size < 1:
      raise ValueError("Parameter output_size must be > 0: %d." % output_size)
    self._cell = cell
    self._output_size = output_size
    self._activation = activation
    self._linear = None
Example #19
  def __init__(self,
               cell,
               embedding,
               start_tokens,
               end_token,
               initial_state,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0):
    """Initialize the BeamSearchDecoder.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width:  Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
        to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
    """
    rnn_cell_impl.assert_like_rnncell("cell",cell)
    if (output_layer is not None and
        not isinstance(output_layer, layers_base.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._output_layer = output_layer

    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = array_ops.size(start_tokens)
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = array_ops.tile(
        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)

    self._finished = array_ops.one_hot(
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
Example #20
    def __init__(self,
                 cell,
                 embedding,
                 start_tokens,
                 end_token,
                 initial_state,
                 beam_width,
                 batch,
                 no_of_users,
                 num_attr,
                 mec_attr,
                 user_attr,
                 vocab_size,
                 output_layer=None,
                 length_penalty_weight=0.0,
                 reorder_tensor_arrays=True):
        """Initialize the BeamSearchDecoder.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width:  Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
        to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
        state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
    """
        rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
        if (output_layer is not None
                and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError("output_layer must be a Layer, received: %s" %
                            type(output_layer))
        self._cell = cell
        self._output_layer = output_layer
        self._reorder_tensor_arrays = reorder_tensor_arrays
        self._batch = batch
        self._beam = beam_width
        self._users = no_of_users
        self._num_attr = num_attr
        self._mec_attr = mec_attr
        self._user_attr = user_attr
        self._vocab_size = vocab_size

        if callable(embedding):
            self._embedding_fn = embedding
        else:
            self._embedding_fn = (
                lambda ids: embedding_ops.embedding_lookup(embedding, ids))

        self._start_tokens = ops.convert_to_tensor(start_tokens,
                                                   dtype=dtypes.int32,
                                                   name="start_tokens")
        if self._start_tokens.get_shape().ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._end_token = ops.convert_to_tensor(end_token,
                                                dtype=dtypes.int32,
                                                name="end_token")
        if self._end_token.get_shape().ndims != 0:
            raise ValueError("end_token must be a scalar")

        self._batch_size = array_ops.size(start_tokens)
        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, initial_state,
            self._cell.state_size)
        self._start_tokens = array_ops.tile(
            array_ops.expand_dims(self._start_tokens, 1),
            [1, self._beam_width])
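        # Note: this custom embedding is invoked as embedding_fn(ids, time),
        # while the embedding_lookup fallback above only accepts ids, so a
        # callable `embedding` appears to be required by this decoder.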
        self._start_inputs = self._embedding_fn(self._start_tokens, 0)

        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size],
                                                           dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)
Example #21
def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
    """Creates a dynamic version of bidirectional recurrent neural network.

  Takes input and builds independent forward and backward RNNs. The input_size
  of forward and backward cell must match. The initial state for both directions
  is zero by default (but can be set optionally) and no intermediate states are
  ever returned -- the network is fully unrolled for the given (passed in)
  length(s) of the sequence(s) or completely unrolled if length(s) is not
  given.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and time
      reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in the `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
    rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
    rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

    with vs.variable_scope(scope or "bidirectional_rnn"):
        # Forward direction
        with vs.variable_scope("fw") as fw_scope:
            output_fw, output_state_fw = dynamic_rnn(
                cell=cell_fw,
                inputs=inputs,
                sequence_length=sequence_length,
                initial_state=initial_state_fw,
                dtype=dtype,
                parallel_iterations=parallel_iterations,
                swap_memory=swap_memory,
                time_major=time_major,
                scope=fw_scope)

        # Backward direction
        if not time_major:
            time_axis = 1
            batch_axis = 0
        else:
            time_axis = 0
            batch_axis = 1

        def _reverse(input_, seq_lengths, seq_axis, batch_axis):
            if seq_lengths is not None:
                return array_ops.reverse_sequence(input=input_,
                                                  seq_lengths=seq_lengths,
                                                  seq_axis=seq_axis,
                                                  batch_axis=batch_axis)
            else:
                return array_ops.reverse(input_, axis=[seq_axis])

        with vs.variable_scope("bw") as bw_scope:

            def _map_reverse(inp):
                return _reverse(inp,
                                seq_lengths=sequence_length,
                                seq_axis=time_axis,
                                batch_axis=batch_axis)

            inputs_reverse = nest.map_structure(_map_reverse, inputs)
            tmp, output_state_bw = dynamic_rnn(
                cell=cell_bw,
                inputs=inputs_reverse,
                sequence_length=sequence_length,
                initial_state=initial_state_bw,
                dtype=dtype,
                parallel_iterations=parallel_iterations,
                swap_memory=swap_memory,
                time_major=time_major,
                scope=bw_scope)

    output_bw = _reverse(tmp,
                         seq_lengths=sequence_length,
                         seq_axis=time_axis,
                         batch_axis=batch_axis)

    outputs = (output_fw, output_bw)
    output_states = (output_state_fw, output_state_bw)

    return (outputs, output_states)
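A minimal sketch of calling the function above (TF 1.x style; tensor shapes and sizes are illustrative):

import tensorflow as tf  # TF 1.x

cell_fw = tf.nn.rnn_cell.LSTMCell(64)
cell_bw = tf.nn.rnn_cell.LSTMCell(64)
inputs = tf.placeholder(tf.float32, [None, None, 128])  # [batch, time, depth]
lengths = tf.placeholder(tf.int32, [None])              # [batch]

outputs, states = bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, sequence_length=lengths, dtype=tf.float32)
# Forward/backward outputs come back as a tuple; concatenate on the
# feature axis if a single tensor is preferred.
merged = tf.concat(outputs, 2)  # [batch, time, 128]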
Example #22
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=True,
                scope=None):
    """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                     initial_state=initial_state,
                                     dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.contrib.rnn.LSTMStateTuple for each cell
  outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                     inputs=data,
                                     dtype=tf.float32)
  ```


  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False`, this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements. This may also be
        a (possibly nested) tuple of Tensors satisfying this property.  The
        first two dimensions must match across all the inputs, but otherwise the
        ranks and other shape components may differ. In this case, input to
        `cell` at each time-step will replicate the structure of these tuples,
        except for the time dimension (from which the time is taken). The input
        to `cell` at each time step will be a `Tensor` or (possibly nested)
        tuple of Tensors each with dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length.  So it's more for performance than correctness.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation. Unlike the stock
      `tf.nn.dynamic_rnn`, this variant defaults to (and currently only
      supports) the time-major form.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False, this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True, this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    RuntimeError: If not using control flow v2.
  """

    # Currently, only the time_major == True case is supported.
    assert time_major

    # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
    # TfLiteRNNCells.
    rnn_cell_impl.assert_like_rnncell("cell", cell)

    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

    parent_first_child_input = [{
        "parent_ophint_input_index": 0,
        "first_child_ophint_input_index": 0
    }]
    parent_last_child_output = [{
        "parent_output_index": 0,
        # For LstmCell, the index is 2.
        # For RnnCell, the index is 1.
        # So we use -1 meaning it's the last one.
        "child_output_index": -1
    }]
    internal_children_input_output = [{
        "child_input_index": 0,
        # For LstmCell, the index is 2.
        # For RnnCell, the index is 1.
        # So we use -1 meaning it's the last one.
        "child_output_index": -1
    }]
    inputs_outputs_mappings = {
        "parent_first_child_input": parent_first_child_input,
        "parent_last_child_output": parent_last_child_output,
        "internal_children_input_output": internal_children_input_output
    }
    tflite_wrapper = op_hint.OpHint(
        "TfLiteDynamicRnn",
        level=2,
        children_inputs_mappings=inputs_outputs_mappings)
    with vs.variable_scope(scope or "rnn") as varscope:
        # Create a new scope in which the caching device is either
        # determined by the parent scope, or is set to place the cached
        # Variable using the same placement as for the rest of the RNN.
        if _should_cache():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        inputs = tflite_wrapper.add_input(inputs,
                                          name="input",
                                          index_override=0)

        # Internal calculations use time-major inputs: [time, batch, depth].
        # The batch-major transpose below is kept for parity with the stock
        # dynamic_rnn, but is unreachable here given the assert above.
        flat_input = nest.flatten(inputs)

        if not time_major:
            # (batch, time, depth) => (time, batch, depth)
            flat_input = [
                ops.convert_to_tensor(input_) for input_ in flat_input
            ]
            flat_input = tuple(
                _transpose_batch_time(input_) for input_ in flat_input)

        parallel_iterations = parallel_iterations or 32
        if sequence_length is not None:
            sequence_length = math_ops.cast(sequence_length, dtypes.int32)
            if sequence_length.shape.rank not in (None, 1):
                raise ValueError(
                    "sequence_length must be a vector of length batch_size, "
                    "but saw shape: %s" % sequence_length.shape)
            sequence_length = array_ops.identity(  # Just to find it in the graph.
                sequence_length,
                name="sequence_length")

        batch_size = _best_effort_input_batch_size(flat_input)

        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError(
                    "If there is no initial_state, you must give a dtype.")
            if getattr(cell, "get_initial_state", None) is not None:
                state = cell.get_initial_state(inputs=None,
                                               batch_size=batch_size,
                                               dtype=dtype)
            else:
                state = cell.zero_state(batch_size, dtype)

        def _assert_has_shape(x, shape):
            x_shape = array_ops.shape(x)
            packed_shape = array_ops.stack(shape)
            return control_flow_ops.Assert(
                math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
                    "Expected shape for Tensor %s is " % x.name, packed_shape,
                    " but saw shape: ", x_shape
                ])

        if not context.executing_eagerly() and sequence_length is not None:
            # Perform some shape validation
            with ops.control_dependencies(
                [_assert_has_shape(sequence_length, [batch_size])]):
                sequence_length = array_ops.identity(sequence_length,
                                                     name="CheckSeqLen")

        inputs = nest.pack_sequence_as(structure=inputs,
                                       flat_sequence=flat_input)

        outputs, final_state = _dynamic_rnn_loop(
            cell,
            inputs,
            state,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
            sequence_length=sequence_length,
            dtype=dtype)

        # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
        # If we are performing batch-major calculations, transpose output back
        # to shape [batch, time, depth]
        if not time_major:
            # (time, batch, depth) => (batch, time, depth)
            outputs = nest.map_structure(_transpose_batch_time, outputs)
        outputs = tflite_wrapper.add_output(outputs, name="outputs")

        return outputs, final_state
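Because of the `assert time_major` above, inputs to this OpHint variant must already be time-major. A minimal sketch (the cell type is an assumption; shapes are illustrative):

import tensorflow as tf  # TF 1.x

cell = tf.lite.experimental.nn.TFLiteLSTMCell(64)     # assumed TFLite-aware cell
inputs = tf.placeholder(tf.float32, [50, None, 128])  # [max_time, batch, depth]
outputs, state = dynamic_rnn(cell, inputs, dtype=tf.float32, time_major=True)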
Example #23
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the time axis)
    in addition to the outputs at each timestep and the final state

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )
    """
    # assert_like_rnncell raises TypeError itself when `cell` does not look
    # like an RNNCell; it does not return a truthy value, so wrapping it in
    # an `if` would never trigger.
    assert_like_rnncell("Raw rnn cell", cell)
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        in_graph_mode = not context.executing_eagerly()
        if in_graph_mode:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                              array_ops.shape(emit) for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [s.shape if s.shape.is_fully_defined() else
                           array_ops.shape(s) for s in flat_state]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]

        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition, body, loop_vars=[
                time, elements_finished, next_input, state_ta,
                emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        flat_states = nest.flatten(state_ta)
        flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
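The `loop_fn` contract here follows the stock `tf.nn.raw_rnn`: it is called once with `cell_output=None` to initialize, then once per step, and must return `(elements_finished, next_input, next_cell_state, emit_output, next_loop_state)`. A minimal sketch that simply unrolls over time-major inputs (sizes assumed):

import tensorflow as tf  # TF 1.x

max_time, batch_size, depth, units = 50, 32, 128, 64
inputs = tf.placeholder(tf.float32, [max_time, batch_size, depth])
inputs_ta = tf.TensorArray(tf.float32, size=max_time).unstack(inputs)
cell = tf.nn.rnn_cell.LSTMCell(units)
sequence_length = tf.fill([batch_size], max_time)

def loop_fn(time, cell_output, cell_state, loop_state):
    if cell_output is None:  # time == 0: initialization call
        next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
        next_cell_state = cell_state
    elements_finished = time >= sequence_length
    next_input = tf.cond(
        tf.reduce_all(elements_finished),
        lambda: tf.zeros([batch_size, depth], tf.float32),
        lambda: inputs_ta.read(time))
    # Emitting cell_output (None on the first call) lets the emit
    # structure default to cell.output_size.
    return (elements_finished, next_input, next_cell_state, cell_output, None)

states, outputs, final_state = raw_rnn(cell, loop_fn)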
Example #24
    def __init__(self,
                 cell,
                 embedding,
                 start_tokens,
                 end_token,
                 initial_state,
                 beam_width,
                 output_layer=None,
                 length_penalty_weight=0.0,
                 coverage_penalty_weight=0.0,
                 reorder_tensor_arrays=True,
                 skip_tokens_decoding=None,
                 shrink_vocab=0,
                 start_token_logits=None):
        """Initialize the BeamSearchDecoder.

        Args:
                cell: An `RNNCell` instance.
                embedding: A callable that takes a vector tensor of `ids` (argmax ids),
                        or the `params` argument for `embedding_lookup`.
                start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
                end_token: `int32` scalar, the token that marks end of decoding.
                initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
                beam_width:  Python integer, the number of beams.
                output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
                        `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
                        to storing the result or sampling.
                length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
                coverage_penalty_weight: Float weight to penalize the coverage of source
                        sentence. Disabled with 0.0.
                reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
                        state will be reordered according to the beam search path. If the
                        `TensorArray` can be reordered, the stacked form will be returned.
                        Otherwise, the `TensorArray` will be returned as is. Set this flag to
                        `False` if the cell state contains `TensorArray`s that are not amenable
                        to reordering.
                skip_tokens_decoding: A list of tokens that should be skipped while decoding.
                        Defaults to None, that is, not skipping any tokens.
                shrink_vocab: Use only the top N tokens while decoding. Disabled with 0.
                start_token_logits: Logits for the start tokens.  Used if _GO tokens are not
                        used. Defaults to None.

        Raises:
                TypeError: if `cell` is not an instance of `RNNCell`,
                        or `output_layer` is not an instance of `tf.layers.Layer`.
                ValueError: If `start_tokens` is not a vector or
                        `end_token` is not a scalar.
                ValueError: If `start_token_logits` is not provided and
                        `use_go_tokens` is set to False.
        """
        rnn_cell_impl.assert_like_rnncell(
            "cell", cell)  # pylint: disable=protected-access
        if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError(
                "output_layer must be a Layer, received: %s" % type(output_layer))
        self._cell = cell
        self._output_layer = output_layer
        self._reorder_tensor_arrays = reorder_tensor_arrays
        self._use_go_tokens = True  # presumably a constructor argument in the original; hardcoded in this snippet
        if not self._use_go_tokens and start_token_logits is None:
            raise ValueError(
                "start token logits must be provided if use_go_tokens is False")
        
        if callable(embedding):
            self._embedding_fn = embedding
        else:
            self._embedding_fn = (
                lambda ids: embedding_ops.embedding_lookup(embedding, ids))

        if self._use_go_tokens:
            self._start_tokens = ops.convert_to_tensor(
                start_tokens, dtype=dtypes.int32, name="start_tokens")
            if self._start_tokens.get_shape().ndims != 1:
                raise ValueError("start_tokens must be a vector")

        self._end_token = ops.convert_to_tensor(
            end_token, dtype=dtypes.int32, name="end_token")
        if self._end_token.get_shape().ndims != 0:
            raise ValueError("end_token must be a scalar")

        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._coverage_penalty_weight = coverage_penalty_weight
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, initial_state, self._cell.state_size)

        if self._use_go_tokens:
            self._batch_size = array_ops.size(start_tokens)
            self._start_tokens = array_ops.tile(
                array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
        else:
            # The snippet is truncated here. A hedged completion: when _GO tokens
            # are not used, the batch size presumably comes from the provided
            # start-token logits rather than from `start_tokens`.
            self._batch_size = array_ops.shape(start_token_logits)[0]
Exemplo n.º 25
0
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None,
                 attention_layer=None):

        super(AttentionWrapper, self).__init__(name=name)
        rnn_cell_impl.assert_like_rnncell("cell", cell)
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        "attention_mechanism must contain only instances of "
                        "AttentionMechanism, saw type: %s" %
                        type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError(
                    "attention_mechanism must be an AttentionMechanism or list of "
                    "multiple AttentionMechanism instances, saw type: %s" %
                    type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism, )

        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        if attention_layer_size is not None and attention_layer is not None:
            raise ValueError(
                "Only one of attention_layer_size and attention_layer "
                "should be set")

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(attention_layer_size if isinstance(
                attention_layer_size, (list,
                                       tuple)) else (attention_layer_size, ))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError(
                    "If provided, attention_layer_size must contain exactly one "
                    "integer per attention_mechanism, saw: %d vs %d" %
                    (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(
                layers_core.Dense(attention_layer_size,
                                  name="attention_layer",
                                  use_bias=False,
                                  dtype=attention_mechanisms[i].dtype) for i,
                attention_layer_size in enumerate(attention_layer_sizes))
            self._attention_layer_size = sum(attention_layer_sizes)
        elif attention_layer is not None:
            self._attention_layers = tuple(attention_layer if isinstance(
                attention_layer, (list, tuple)) else (attention_layer, ))
            if len(self._attention_layers) != len(attention_mechanisms):
                raise ValueError(
                    "If provided, attention_layer must contain exactly one "
                    "layer per attention_mechanism, saw: %d vs %d" %
                    (len(self._attention_layers), len(attention_mechanisms)))
            self._attention_layer_size = sum(
                tensor_shape.dimension_value(
                    layer.compute_output_shape([
                        None, cell.output_size + tensor_shape.dimension_value(
                            mechanism.values.shape[-1])
                    ])[-1]) for layer, mechanism in zip(
                        self._attention_layers, attention_mechanisms))
        else:
            self._attention_layers = None
            self._attention_layer_size = sum(
                tensor_shape.dimension_value(
                    attention_mechanism.values.shape[-1])
                for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with ops.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (tensor_shape.dimension_value(
                    final_state_tensor.shape[0])
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttentionWrapper %s: " % self._base_name
                    + "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.  Are you using "
                    "the BeamSearchDecoder?  You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: array_ops.identity(
                            s, name="check_initial_cell_state"),
                        initial_cell_state)
Exemplo n.º 26
0
    def __init__(
            self,
            cell,
            attention_mechanism,
            is_manual_attention,  # added argument
            manual_alignments,  # added argument
            attention_layer_size=None,
            alignment_history=False,
            cell_input_fn=None,
            output_attention=True,
            initial_cell_state=None,
            name=None):
        """Construct the `AttentionWrapper`.
        **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
        `AttentionWrapper`, then you must ensure that:
        - The encoder output has been tiled to `beam_width` via
          `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
        - The `batch_size` argument passed to the `zero_state` method of this
          wrapper is equal to `true_batch_size * beam_width`.
        - The initial state created with `zero_state` above contains a
          `cell_state` value containing properly tiled final state from the
          encoder.
        An example:
        ```
        tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
            encoder_outputs, multiplier=beam_width)
        tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
            encoder_final_state, multiplier=beam_width)
        tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
            sequence_length, multiplier=beam_width)
        attention_mechanism = MyFavoriteAttentionMechanism(
            num_units=attention_depth,
            memory=tiled_encoder_outputs,
            memory_sequence_length=tiled_sequence_length)
        attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
        decoder_initial_state = attention_cell.zero_state(
            dtype, batch_size=true_batch_size * beam_width)
        decoder_initial_state = decoder_initial_state.clone(
            cell_state=tiled_encoder_final_state)
        ```
        Args:
          cell: An instance of `RNNCell`.
          attention_mechanism: A list of `AttentionMechanism` instances or a single
            instance.
          attention_layer_size: A list of Python integers or a single Python
            integer, the depth of the attention (output) layer(s). If None
            (default), use the context as attention at each time step. Otherwise,
            feed the context and cell output into the attention layer to generate
            attention at each time step. If attention_mechanism is a list,
            attention_layer_size must be a list of the same length.
          alignment_history: Python boolean, whether to store alignment history
            from all time steps in the final output state (currently stored as a
            time major `TensorArray` on which you must call `stack()`).
          cell_input_fn: (optional) A `callable`.  The default is:
            `lambda inputs, attention: tf.concat([inputs, attention], -1)`.
          output_attention: Python bool.  If `True` (default), the output at each
            time step is the attention value.  This is the behavior of Luong-style
            attention mechanisms.  If `False`, the output at each time step is
            the output of `cell`.  This is the behavior of Bahdanau-style
            attention mechanisms.  In both cases, the `attention` tensor is
            propagated to the next time step via the state and is used there.
            This flag only controls whether the attention mechanism is propagated
            up to the next cell in an RNN stack or to the top RNN output.
          initial_cell_state: The initial state value to use for the cell when
            the user calls `zero_state()`.  Note that if this value is provided
            now, and the user uses a `batch_size` argument of `zero_state` which
            does not match the batch size of `initial_cell_state`, proper
            behavior is not guaranteed.
          name: Name to use when creating ops.
        Raises:
          TypeError: `attention_layer_size` is not None and (`attention_mechanism`
            is a list but `attention_layer_size` is not; or vice versa).
          ValueError: if `attention_layer_size` is not None, `attention_mechanism`
            is a list, and its length does not match that of `attention_layer_size`.
        """
        super(AttentionWrapper, self).__init__(name=name)

        self.is_manual_attention = is_manual_attention
        self.manual_alignments = manual_alignments

        rnn_cell_impl.assert_like_rnncell("cell", cell)
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True

            attention_mechanisms = attention_mechanism
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        "attention_mechanism must contain only instances of "
                        "AttentionMechanism, saw type: %s" %
                        type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError(
                    "attention_mechanism must be an AttentionMechanism or list of "
                    "multiple AttentionMechanism instances, saw type: %s" %
                    type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism, )

        if cell_input_fn is None:
            cell_input_fn = (
                lambda inputs, attention: tf.concat([inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(attention_layer_size if isinstance(
                attention_layer_size, (list,
                                       tuple)) else (attention_layer_size, ))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError(
                    "If provided, attention_layer_size must contain exactly one "
                    "integer per attention_mechanism, saw: %d vs %d" %
                    (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(
                layers_core.Dense(attention_layer_size,
                                  name="attention_layer",
                                  use_bias=False,
                                  dtype=attention_mechanisms[i].dtype) for i,
                attention_layer_size in enumerate(attention_layer_sizes))
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            self._attention_layers = None
            self._attention_layer_size = sum(
                attention_mechanism.values.get_shape()[-1].value
                for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with tf.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value
                                    or tf.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttentionWrapper %s: " % self._base_name
                    + "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.  Are you using "
                    "the BeamSearchDecoder?  You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with tf.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: tf.identity(s,
                                              name="check_initial_cell_state"),
                        initial_cell_state)
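As the NOTE in the docstring above stresses, combining this wrapper with `BeamSearchDecoder` requires tiling the encoder tensors first. A hedged sketch following that recipe; `encoder_outputs`, `encoder_final_state`, `sequence_length`, `cell`, `true_batch_size`, `beam_width` and `attention_depth` are assumed inputs, and suitable values for the manual-attention arguments depend on this variant's call logic:

import tensorflow as tf

tiled_memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_final_state = tf.contrib.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(sequence_length, multiplier=beam_width)

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=attention_depth,
    memory=tiled_memory,
    memory_sequence_length=tiled_lengths)
attention_cell = AttentionWrapper(
    cell, attention_mechanism,
    is_manual_attention=False,   # disable the manual-alignment path
    manual_alignments=None,
    attention_layer_size=attention_depth)
decoder_initial_state = attention_cell.zero_state(
    batch_size=true_batch_size * beam_width, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_final_state)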
Exemplo n.º 27
0
def modified_static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
  rnn_cell_impl.assert_like_rnncell("cell", cell)
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  outputs = []
  states = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_sequence(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().ndims != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = input_shape[0], input_shape[1:]
        fixed_batch_size.merge_with(batch_size)
        for i, size in enumerate(input_size):
          if size.value is None:
            raise ValueError(
                "Input size (dimension %d of inputs) must be accessible via "
                "shape inference, but saw value None." % i)
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, "
                         "dtype must be specified")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().ndims not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size")

      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _concat(batch_size, output_size)
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        shape = _concat(fixed_batch_size.value, output_size, static=True)
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(
          structure=output_size, flat_sequence=flat_zero_output)

      sequence_length = math_ops.to_int32(sequence_length)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    # Keras RNN cells only accept state as list, even if it's a single tensor.
    is_keras_rnn_cell = _is_keras_rnn_cell(cell)
    if is_keras_rnn_cell and not nest.is_sequence(state):
      state = [state]
    for time, input_ in enumerate(inputs):
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()
      outputs.append(output)
      states.append(state)
    # Keras RNN cells only return state as list, even if it's a single tensor.
    if is_keras_rnn_cell and len(state) == 1:
      state = state[0]

    return (outputs, state, states)
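A hedged sketch of calling `modified_static_rnn`: like `tf.nn.static_rnn`, it takes a Python list of per-timestep `[batch_size, input_depth]` tensors, and this variant additionally returns the list of per-step states as a third value:

import tensorflow as tf

max_time, batch_size, input_depth = 5, 8, 16
inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_depth])
inputs_as_list = tf.unstack(inputs, num=max_time, axis=1)  # max_time tensors of [batch, depth]

cell = tf.nn.rnn_cell.BasicLSTMCell(32)
outputs, final_state, all_states = modified_static_rnn(
    cell, inputs_as_list, dtype=tf.float32)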
Exemplo n.º 28
0
  def __init__(self,
               cell,
               embedding,
               start_tokens,
               end_token,
               initial_state,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0,
               reorder_tensor_arrays=True):
    """Initialize the BeamSearchDecoder.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width:  Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
        to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
        state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
    if (output_layer is not None and
        not isinstance(output_layer, layers_base.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._output_layer = output_layer
    self._reorder_tensor_arrays = reorder_tensor_arrays

    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = array_ops.size(start_tokens)
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = array_ops.tile(
        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)

    # Initially, only beam 0 of each batch entry is marked unfinished, so the
    # first step expands a single beam rather than beam_width identical copies.
    self._finished = array_ops.one_hot(
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
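A hedged usage sketch for this decoder, assuming `decoder_cell`, `embedding_matrix`, `encoder_final_state`, `vocab_size`, `batch_size`, `GO_ID` and `EOS_ID` already exist; the initial state must be tiled to `beam_width`:

import tensorflow as tf

beam_width = 4
tiled_state = tf.contrib.seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)
decoder = BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embedding_matrix,   # or any callable mapping ids -> embeddings
    start_tokens=tf.fill([batch_size], GO_ID),
    end_token=EOS_ID,
    initial_state=tiled_state,
    beam_width=beam_width,
    output_layer=tf.layers.Dense(vocab_size, use_bias=False),
    length_penalty_weight=0.6)
outputs, final_state, lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=50)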
Exemplo n.º 29
0
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=True,
                scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                     initial_state=initial_state,
                                     dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.contrib.rnn.LSTMStateTuple for each cell
  outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                     inputs=data,
                                     dtype=tf.float32)
  ```


  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements. This may also be
        a (possibly nested) tuple of Tensors satisfying this property.  The
        first two dimensions must match across all the inputs, but otherwise the
        ranks and other shape components may differ. In this case, input to
        `cell` at each time-step will replicate the structure of these tuples,
        except for the time dimension (from which the time is taken). The input
        to `cell` at each time step will be a `Tensor` or (possibly nested)
        tuple of Tensors each with dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length.  So it's more for performance than correctness.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False (default), this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True, this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    RuntimeError: If not using control flow v2.
  """

  # Currently only support time_major == True case.
  assert time_major

  # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
  # TfLiteRNNCells.
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
    raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

  parent_first_child_input = [{
      "parent_ophint_input_index": 0,
      "first_child_ophint_input_index": 0
  }]
  parent_last_child_output = [{
      "parent_output_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
  internal_children_input_output = [{
      "child_input_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
  inputs_outputs_mappings = {
      "parent_first_child_input": parent_first_child_input,
      "parent_last_child_output": parent_last_child_output,
      "internal_children_input_output": internal_children_input_output
  }
  tflite_wrapper = op_hint.OpHint(
      "TfLiteDynamicRnn",
      level=2,
      children_inputs_mappings=inputs_outputs_mappings)
  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)

    # When time_major == False, inputs are batch-major, shaped [batch, time, depth],
    # and are transposed to [time, batch, depth] for internal calculations.
    # (This variant asserts time_major == True above, so no transpose happens here.)
    flat_input = nest.flatten(inputs)

    if not time_major:
      # (batch, time, depth) => (time, batch, depth)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.to_int32(sequence_length)
      if sequence_length.get_shape().rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size, "
            "but saw shape: %s" % sequence_length.get_shape())
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length,
          name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    outputs, final_state = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth]
    if not time_major:
      # (time, batch, depth) => (batch, time, depth)
      outputs = nest.map_structure(_transpose_batch_time, outputs)
    outputs = tflite_wrapper.add_output(outputs, name="outputs")

    return outputs, final_state
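A hedged sketch of invoking this OpHint-wrapped variant: it asserts `time_major=True`, so inputs must already be `[max_time, batch_size, depth]`, and control flow v2 must be enabled; the shapes below are illustrative:

import tensorflow as tf

max_time, batch_size, depth = 10, 4, 8
inputs = tf.placeholder(tf.float32, [max_time, batch_size, depth])
cell = tf.nn.rnn_cell.BasicRNNCell(16)
outputs, final_state = dynamic_rnn(cell, inputs, dtype=tf.float32, time_major=True)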
Exemplo n.º 30
0
def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
  """Creates a dynamic version of bidirectional recurrent neural network.

  Takes input and builds independent forward and backward RNNs. The input_size
  of forward and backward cell must match. The initial state for both directions
  is zero by default (but can be set optionally) and no intermediate states are
  ever returned -- the network is fully unrolled for the given (passed in)
  length(s) of the sequence(s) or completely unrolled if length(s) is not
  given.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and time
      reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in the `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_fw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=fw_scope)

    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_,
            seq_lengths=seq_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw,
          inputs=inputs_reverse,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=bw_scope)

  output_bw = _reverse(
      tmp,
      seq_lengths=sequence_length,
      seq_axis=time_axis,
      batch_axis=batch_axis)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)
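A hedged usage sketch: run both directions and, if a single tensor is preferred, concatenate the two outputs along the feature axis as the docstring suggests:

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, None, 128])  # [batch, time, depth]
seq_len = tf.placeholder(tf.int32, [None])
cell_fw = tf.nn.rnn_cell.LSTMCell(64)
cell_bw = tf.nn.rnn_cell.LSTMCell(64)

outputs, output_states = bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, sequence_length=seq_len, dtype=tf.float32)
outputs_concat = tf.concat(outputs, 2)  # [batch, time, 64 + 64]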
Exemplo n.º 31
0
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
    """Creates a recurrent neural network specified by RNNCell `cell`.

    All of the code in this function body validates the input arguments; it contains no algorithmic logic, so you can skip the whole function if you are not interested.

    Performs fully dynamic unrolling of `inputs`.

    Example:

    ```python
    # create a BasicRNNCell
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

    # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

    # defining initial state
    initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

    # 'state' is a tensor of shape [batch_size, cell_state_size]
    outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                       initial_state=initial_state,
                                       dtype=tf.float32)
    ```

    ```python
    # create 2 LSTMCells
    rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

    # create a RNN cell composed sequentially of a number of RNNCells
    multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

    # 'outputs' is a tensor of shape [batch_size, max_time, 256]
    # 'state' is a N-tuple where N is the number of LSTMCells containing a
    # tf.contrib.rnn.LSTMStateTuple for each cell
    outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                       inputs=data,
                                       dtype=tf.float32)
    ```


    Args:
      cell: An instance of RNNCell.
      inputs: The RNN inputs.
        If `time_major == False` (default), this must be a `Tensor` of shape:
          `[batch_size, max_time, ...]`, or a nested tuple of such
          elements.
        If `time_major == True`, this must be a `Tensor` of shape:
          `[max_time, batch_size, ...]`, or a nested tuple of such
          elements.
        This may also be a (possibly nested) tuple of Tensors satisfying
        this property.  The first two dimensions must match across all the inputs,
        but otherwise the ranks and other shape components may differ.
        In this case, input to `cell` at each time-step will replicate the
        structure of these tuples, except for the time dimension (from which the
        time is taken).
        The input to `cell` at each time step will be a `Tensor` or (possibly
        nested) tuple of Tensors each with dimensions `[batch_size, ...]`.
      sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
        Used to copy-through state and zero-out outputs when past a batch
        element's sequence length.  So it's more for performance than correctness.
      initial_state: (optional) An initial state for the RNN.
        If `cell.state_size` is an integer, this must be
        a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
        If `cell.state_size` is a tuple, this should be a tuple of
        tensors having shapes `[batch_size, s] for s in cell.state_size`.
      dtype: (optional) The data type for the initial state and expected output.
        Required if initial_state is not provided or RNN state has a heterogeneous
        dtype.
      parallel_iterations: (Default: 32).  The number of iterations to run in
        parallel.  Those operations which do not have any temporal dependency
        and can be run in parallel, will be.  This parameter trades off
        time for space.  Values >> 1 use more memory but take less time,
        while smaller values use less memory but computations take longer.
      swap_memory: Transparently swap the tensors produced in forward inference
        but needed for back prop from GPU to CPU.  This allows training RNNs
        which would typically not fit on a single GPU, with very minimal (or no)
        performance penalty.
      time_major: The shape format of the `inputs` and `outputs` Tensors.
        If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
        If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
        Using `time_major = True` is a bit more efficient because it avoids
        transposes at the beginning and end of the RNN calculation.  However,
        most TensorFlow data is batch-major, so by default this function
        accepts input and emits output in batch-major form.
      scope: VariableScope for the created subgraph; defaults to "rnn".

    Returns:
      A pair (outputs, state) where:

      outputs: The RNN output `Tensor`.

        If time_major == False (default), this will be a `Tensor` shaped:
          `[batch_size, max_time, cell.output_size]`.

        If time_major == True, this will be a `Tensor` shaped:
          `[max_time, batch_size, cell.output_size]`.

        Note, if `cell.output_size` is a (possibly nested) tuple of integers
        or `TensorShape` objects, then `outputs` will be a tuple having the
        same structure as `cell.output_size`, containing Tensors having shapes
        corresponding to the shape data in `cell.output_size`.

      state: The final state.  If `cell.state_size` is an int, this
        will be shaped `[batch_size, cell.state_size]`.  If it is a
        `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
        If it is a (possibly nested) tuple of ints or `TensorShape`, this will
        be a tuple having the corresponding shapes. If cells are `LSTMCells`
        `state` will be a tuple containing a `LSTMStateTuple` for each cell.

    Raises:
      TypeError: If `cell` is not an instance of RNNCell.
      ValueError: If inputs is None or an empty list.
    """

    # Because Python is dynamically typed, type checks cannot happen at compile time
    # as in statically typed languages; argument validity must be checked at runtime.
    rnn_cell_impl.assert_like_rnncell("cell", cell)

    with vs.variable_scope(scope or "rnn") as varscope:
        # Create a new scope in which the caching device is either
        # determined by the parent scope, or is set to place the cached
        # Variable using the same placement as for the rest of the RNN.
        if _should_cache():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # By default, time_major==False and inputs are batch-major: shaped
        #   [batch, time, depth]
        # For internal calculations, we transpose to [time, batch, depth]
        # dynamic_rnn actually accepts either a single [batch, time, depth] tensor or a
        # container: a set, dict, or list, possibly nested arbitrarily deep.
        # In other words, the original design allows dynamic_rnn to take several mini-batches
        # at once; the container only has to guarantee that every basic element is a
        # [batch, time, depth] tensor of the same size.
        # This feature does not seem to be implemented, though; I just filed an issue with
        # TensorFlow and they have not replied yet.
        # nest.flatten "flattens" the container into a flat list; after processing, the
        # structure can be restored with nest.pack_sequence_as.
        flat_input = nest.flatten(inputs)

        if not time_major:
            # (B,T,D) => (T,B,D)
            flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
            flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

        parallel_iterations = parallel_iterations or 32
        if sequence_length is not None:
            sequence_length = math_ops.to_int32(sequence_length)
            if sequence_length.get_shape().ndims not in (None, 1):
                raise ValueError(
                    "sequence_length must be a vector of length batch_size, "
                    "but saw shape: %s" % sequence_length.get_shape())
            sequence_length = array_ops.identity(  # Just to find it in the graph.
                sequence_length, name="sequence_length")

        # Get batch_size; this helper looks somewhat redundant to me, but there may be a design reason behind it.
        batch_size = _best_effort_input_batch_size(flat_input)

        # Create the initial state. If you trace cell.zero_state you will find it remarkably
        # complex, full of tricks, yet all it really does is create a [batch_size, num_hidden]
        # zero matrix. It is written that way because the method lives on RNNCell and is not
        # overridden by the subclasses, so it has to be generic enough to satisfy each cell
        # type's initialization needs (mainly LSTM, which initializes both a cell state and a
        # hidden state, while most other cells have only one).
        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError("If there is no initial_state, you must give a dtype.")
            if getattr(cell, "get_initial_state", None) is not None:
                state = cell.get_initial_state(
                    inputs=None, batch_size=batch_size, dtype=dtype)
            else:
                state = cell.zero_state(batch_size, dtype)

        def _assert_has_shape(x, shape):
            x_shape = array_ops.shape(x)
            packed_shape = array_ops.stack(shape)
            return control_flow_ops.Assert(
                math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)),
                ["Expected shape for Tensor %s is " % x.name,
                 packed_shape, " but saw shape: ", x_shape])

        # Validate that sequence_length has a legal shape.
        if not context.executing_eagerly() and sequence_length is not None:
            # Perform some shape validation
            with ops.control_dependencies(
                    [_assert_has_shape(sequence_length, [batch_size])]):
                sequence_length = array_ops.identity(
                    sequence_length, name="CheckSeqLen")

        # flat_input validation is complete.
        # Repack the flattened data into the structure of the original `inputs`.
        inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

        (outputs, final_state) = _dynamic_rnn_loop(
            cell,
            inputs,
            state,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
            sequence_length=sequence_length,
            dtype=dtype)

        # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
        # If we are performing batch-major calculations, transpose output back
        # to shape [batch, time, depth]
        if not time_major:
            # (T,B,D) => (B,T,D)
            outputs = nest.map_structure(_transpose_batch_time, outputs)

        return (outputs, final_state)
Exemplo n.º 32
0
  def __init__(self,
               std_cell,
               cue_cell,
               cue_inputs,
               fact_candidates,
               lengths_for_fact_candidates,
               kgp_initial_goals,
               attention_mechanism,
               encoder_memory=None,
               encoder_memory_len=None,
               attention_layer_size=None,
               k_openness_history=False,
               alignment_history=False,
               cell_input_fn=None,
               output_attention=True,
               initial_cell_state=None,
               common_word_projection=None,
               entity_predict_mode=False,
               copy_predict_mode=False,
               balance_gate=True,
               cue_fact_mode=0,
               cue_fact_mask=False,
               vocab_sizes=None,
               name=None,
               sim_dim=64,
               mid_projection_dim=1280,
               binary_decoding=False,
               attention_layer=None):
    """Construct the `AttentionWrapper`.

    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
    `AttentionWrapper`, then you must ensure that:

    - The encoder output has been tiled to `beam_width` via
      `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
    - The `batch_size` argument passed to the `zero_state` method of this
      wrapper is equal to `true_batch_size * beam_width`.
    - The initial state created with `zero_state` above contains a
      `cell_state` value containing properly tiled final state from the
      encoder.

    An example:

    ```
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.zero_state(
        dtype, batch_size=true_batch_size * beam_width)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```

    Args:
      cell: An instance of `RNNCell`.
      attention_mechanism: A list of `AttentionMechanism` instances or a single
        instance.
      attention_layer_size: A list of Python integers or a single Python
        integer, the depth of the attention (output) layer(s). If None
        (default), use the context as attention at each time step. Otherwise,
        feed the context and cell output into the attention layer to generate
        attention at each time step. If attention_mechanism is a list,
        attention_layer_size must be a list of the same length. If
        attention_layer is set, this must be None.
      alignment_history: Python boolean, whether to store alignment history
        from all time steps in the final output state (currently stored as a
        time major `TensorArray` on which you must call `stack()`).
      cell_input_fn: (optional) A `callable`.  The default is:
        `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
      output_attention: Python bool.  If `True` (default), the output at each
        time step is the attention value.  This is the behavior of Luong-style
        attention mechanisms.  If `False`, the output at each time step is
        the output of the cell.  This is the behavior of Bahdanau-style
        attention mechanisms.  In both cases, the `attention` tensor is
        propagated to the next time step via the state and is used there.
        This flag only controls whether the attention mechanism is propagated
        up to the next cell in an RNN stack or to the top RNN output.
      initial_cell_state: The initial state value to use for the cell when
        the user calls `zero_state()`.  Note that if this value is provided
        now, and the user uses a `batch_size` argument of `zero_state` which
        does not match the batch size of `initial_cell_state`, proper
        behavior is not guaranteed.
      name: Name to use when creating ops.
      attention_layer: A list of `tf.layers.Layer` instances or a
        single `tf.layers.Layer` instance taking the context and cell output as
        inputs to generate attention at each time step. If None (default), use
        the context as attention at each time step. If attention_mechanism is a
        list, attention_layer must be a list of the same length. If
        attention_layer_size is set, this must be None.

    Raises:
      TypeError: `attention_layer_size` is not None and (`attention_mechanism`
        is a list but `attention_layer_size` is not; or vice versa).
      ValueError: if `attention_layer_size` is not None, `attention_mechanism`
        is a list, and its length does not match that of `attention_layer_size`;
        if `attention_layer_size` and `attention_layer` are set simultaneously.
    """

    if entity_predict_mode or copy_predict_mode:
        assert common_word_projection is not None and vocab_sizes is not None

    self.balance_gate = balance_gate
    if vocab_sizes is not None:
        self._common_vocab_size, self._copy_vocab_size, self._entity_vocab_size = vocab_sizes
        self._vocab_sizes = vocab_sizes

    self._cue_fact_mode = cue_fact_mode
    self._cue_fact_mask = cue_fact_mask
    self._entity_predict_mode = entity_predict_mode
    self._copy_predict_mode = copy_predict_mode
    self._fact_candidates = fact_candidates
    self._lengths_for_fact_candidates = lengths_for_fact_candidates
    self._batch_size = tf.shape(cue_inputs)[0]
    self._kg_initial_goals = kgp_initial_goals
    if kgp_initial_goals is not None:
        self.decoding_mask_template = kgp_initial_goals
    else:
        self.decoding_mask_template = tf.reduce_max(fact_candidates, -1)
    self._common_word_projection = common_word_projection
    self._encoder_memory = encoder_memory
    self._encoder_memory_len = encoder_memory_len
    self._sim_vec_dim = sim_dim

    self.mid_projection_dim = mid_projection_dim
    self._binary_decoding = binary_decoding

    if copy_predict_mode:
        self._transformed_encoder_memory = tf.layers.dense(self._encoder_memory, units=self._sim_vec_dim,
                                                           activation=tf.nn.tanh,
                                                           name='encoder_memory_transformed')


    super(AttentionWrapper, self).__init__(name=name)

    self.k_openness_history = k_openness_history
    rnn_cell_impl.assert_like_rnncell("cell", std_cell)
    rnn_cell_impl.assert_like_rnncell("cell", cue_cell)
    if isinstance(attention_mechanism, (list, tuple)):
      self._is_multi = True
      attention_mechanisms = attention_mechanism
      for attention_mechanism in attention_mechanisms:
        if not isinstance(attention_mechanism, attention_wrapper.AttentionMechanism):
          raise TypeError(
              "attention_mechanism must contain only instances of "
              "AttentionMechanism, saw type: %s"
              % type(attention_mechanism).__name__)
    else:
      self._is_multi = False
      if not isinstance(attention_mechanism, attention_wrapper.AttentionMechanism):
        raise TypeError(
            "attention_mechanism must be an AttentionMechanism or list of "
            "multiple AttentionMechanism instances, saw type: %s"
            % type(attention_mechanism).__name__)
      attention_mechanisms = (attention_mechanism,)

    if cell_input_fn is None:
      cell_input_fn = (
          lambda inputs, attention: array_ops.concat([inputs, attention], -1))
    else:
      if not callable(cell_input_fn):
        raise TypeError(
            "cell_input_fn must be callable, saw type: %s"
            % type(cell_input_fn).__name__)

    if attention_layer_size is not None and attention_layer is not None:
      raise ValueError("Only one of attention_layer_size and attention_layer "
                       "should be set")

    if attention_layer_size is not None:
      attention_layer_sizes = tuple(
          attention_layer_size
          if isinstance(attention_layer_size, (list, tuple))
          else (attention_layer_size,))
      if len(attention_layer_sizes) != len(attention_mechanisms):
        raise ValueError(
            "If provided, attention_layer_size must contain exactly one "
            "integer per attention_mechanism, saw: %d vs %d"
            % (len(attention_layer_sizes), len(attention_mechanisms)))
      self._attention_layers = tuple(
          layers_core.Dense(
              attention_layer_size,
              name="attention_layer",
              use_bias=False,
              dtype=attention_mechanisms[i].dtype)
          for i, attention_layer_size in enumerate(attention_layer_sizes))
      self._attention_layer_size = sum(attention_layer_sizes)
    elif attention_layer is not None:
      self._attention_layers = tuple(
          attention_layer
          if isinstance(attention_layer, (list, tuple))
          else (attention_layer,))
      if len(self._attention_layers) != len(attention_mechanisms):
        raise ValueError(
            "If provided, attention_layer must contain exactly one "
            "layer per attention_mechanism, saw: %d vs %d"
            % (len(self._attention_layers), len(attention_mechanisms)))
      self._attention_layer_size = sum(
          layer.compute_output_shape(
              [None,
               std_cell.output_size + mechanism.values.shape[-1].value])[-1].value
          for layer, mechanism in zip(
              self._attention_layers, attention_mechanisms))
    else:
      self._attention_layers = None
      self._attention_layer_size = sum(
          attention_mechanism.values.get_shape()[-1].value
          for attention_mechanism in attention_mechanisms)

    self._std_cell = std_cell
    self._cue_cell = cue_cell
    self._cue_inputs = cue_inputs
    self._attention_mechanisms = attention_mechanisms
    self._cell_input_fn = cell_input_fn
    self._output_attention = output_attention
    self._alignment_history = alignment_history
    with ops.name_scope(name, "AttentionWrapperInit"):
      if initial_cell_state is None:
        self._initial_cell_state = None
      else:
        final_state_tensor = nest.flatten(initial_cell_state)[-1]
        state_batch_size = (
            final_state_tensor.shape[0].value
            or array_ops.shape(final_state_tensor)[0])
        error_message = (
            "When constructing AttentionWrapper %s: " % self._base_name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and initial_cell_state.  Are you using "
            "the BeamSearchDecoder?  You may need to tile your initial state "
            "via the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with ops.control_dependencies(
            self._batch_size_checks(state_batch_size, error_message)):
          self._initial_cell_state = nest.map_structure(
              lambda s: array_ops.identity(s, name="check_initial_cell_state"),
              initial_cell_state)
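
# A hedged usage sketch (not from the original source; assumes TF 1.x with
# tf.contrib available) of the beam-search tiling recipe from the constructor
# docstring above. `encoder_outputs`, `encoder_final_state`, `sequence_length`,
# `beam_width` and `attention_depth` stand in for values produced elsewhere.
import tensorflow as tf

def make_beam_ready_attention(encoder_outputs, encoder_final_state,
                              sequence_length, beam_width, attention_depth):
    tiled_memory = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_state = tf.contrib.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_lengths = tf.contrib.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    # Any AttentionMechanism works here; LuongAttention is just one choice.
    mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=attention_depth,
        memory=tiled_memory,
        memory_sequence_length=tiled_lengths)
    return mechanism, tiled_state
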
Exemplo n.º 33
def bi_lstm_attention(inputs,
                      sequence_length,
                      cell_fw,
                      cell_bw,
                      W_forward,
                      v_forward,
                      W_backward,
                      v_backward,
                      enc_padding_mask,
                      scope=None,
                      time_major=False):
    """Bidirectional RNN in which each step attends over the outputs emitted
    so far in its own direction. `linear`, `attention` and `initial_state`
    are helpers defined elsewhere in this module.
    """

    rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
    rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

    time_step = inputs.get_shape()[1].value
    outputs_fw = []
    outputs_bw = []
    batch_size = inputs.get_shape()[0].value
    emb_dim = inputs.get_shape()[-1].value
    state_fw = initial_state(shape=[batch_size, cell_fw._num_units],
                             type=inputs.dtype)
    state_bw = initial_state(shape=[batch_size, cell_bw._num_units],
                             type=inputs.dtype)
    with vs.variable_scope(scope or "bidirectional_rnn"):
        # Forward direction
        context_vector = tf.zeros(shape=[batch_size, cell_fw._num_units],
                                  dtype=inputs.dtype)
        with vs.variable_scope("fw") as fw_scope:
            for i in range(time_step):
                print('encoder attention forward : {0}...'.format(i))
                input_ = inputs[:, i, :]  # [batch, embedding_size]
                x = linear([input_] + [context_vector], emb_dim, False)
                cell_output, state_fw = cell_fw(x, state_fw)
                if i > 1:
                    tf.get_variable_scope().reuse_variables()
                if i > 0:
                    context_vector = attention(state_fw,
                                               tf.stack(outputs_fw,
                                                        axis=0), W_forward,
                                               v_forward, enc_padding_mask)
                outputs_fw.append(cell_output)
            outputs_state_fw = state_fw

        # Backward direction
        if not time_major:
            time_dim = 1
            batch_dim = 0
        else:
            time_dim = 0
            batch_dim = 1

        def _reverse(input_, seq_lengths, seq_dim, batch_dim):
            if seq_lengths is not None:
                return array_ops.reverse_sequence(input=input_,
                                                  seq_lengths=seq_lengths,
                                                  seq_dim=seq_dim,
                                                  batch_dim=batch_dim)
            else:
                return array_ops.reverse(input_, axis=[seq_dim])

        with vs.variable_scope("bw") as bw_scope:
            context_vector = tf.zeros(shape=[batch_size, cell_bw._num_units],
                                      dtype=inputs.dtype)
            for i in range(time_step - 1, -1, -1):
                print('encoder attention backward : {0}...'.format(i))
                input_ = inputs[:, i, :]  # [batch, embedding_size]
                x = linear([input_] + [context_vector], emb_dim, False)
                cell_output, state_bw = cell_bw(x, state_bw)
                if i < time_step - 2:
                    tf.get_variable_scope().reuse_variables()
                if i < time_step - 1:
                    # attend over the backward outputs, with the padding mask
                    # reversed along the time axis
                    context_vector = attention(state_bw,
                                               tf.stack(outputs_bw, axis=0),
                                               W_backward, v_backward,
                                               enc_padding_mask[:, ::-1])
                outputs_bw.append(cell_output)
            outputs_state_bw = state_bw

    tmp = tf.stack(outputs_bw, axis=1)
    outputs_fw = tf.stack(outputs_fw, axis=1)
    outputs_bw = _reverse(tmp,
                          seq_lengths=sequence_length,
                          seq_dim=time_dim,
                          batch_dim=batch_dim)

    outputs = (outputs_fw, outputs_bw)
    output_states = (outputs_state_fw, outputs_state_bw)

    return (outputs, output_states)
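
# A small demonstration (not from the original source; assumes TF 1.x and
# NumPy) of the `reverse_sequence` behaviour that `_reverse` above relies on:
# only the first `seq_lengths[b]` time steps of each batch row are reversed,
# while padding positions keep their place.
import numpy as np
import tensorflow as tf

def demo_reverse_sequence():
    x = tf.constant(np.arange(8).reshape(2, 4), dtype=tf.int32)  # [batch=2, time=4]
    # Row 0 (length 2): [0, 1, 2, 3] -> [1, 0, 2, 3]
    # Row 1 (length 3): [4, 5, 6, 7] -> [6, 5, 4, 7]
    return tf.reverse_sequence(x, seq_lengths=[2, 3], seq_dim=1, batch_dim=0)
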
Exemplo n.º 34
def static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.
  The simplest form of RNN network generated is:
  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```
  However, a few other options are available:
  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.
  The dynamic calculation performed is, at time `t` for batch row `b`,
  ```python
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))
  ```
  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`, or a nested tuple of such elements.
    initial_state: (optional) An initial state for the RNN.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
    scope: VariableScope for the created subgraph; defaults to "rnn".
  Returns:
    A pair (outputs, state) where:
    - outputs is a length T list of outputs (one for each input), or a nested
      tuple of such elements.
    - state is the final state
  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the input depth
      (column size) cannot be inferred from inputs via shape inference.
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_sequence(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().ndims != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = input_shape[0], input_shape[1:]
        fixed_batch_size.merge_with(batch_size)
        for i, size in enumerate(input_size):
          if size.value is None:
            raise ValueError(
                "Input size (dimension %d of inputs) must be accessible via "
                "shape inference, but saw value None." % i)
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, "
                         "dtype must be specified")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().ndims not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size")

      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _concat(batch_size, output_size)
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        shape = _concat(fixed_batch_size.value, output_size, static=True)
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(
          structure=output_size, flat_sequence=flat_zero_output)

      sequence_length = math_ops.to_int32(sequence_length)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    # Keras RNN cells only accept state as list, even if it's a single tensor.
    is_keras_rnn_cell = _is_keras_rnn_cell(cell)
    if is_keras_rnn_cell and not nest.is_sequence(state):
      state = [state]
    for time, input_ in enumerate(inputs):
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      # Note: unlike stock static_rnn, this variant also forwards the loop
      # index `time` to the cell's `call` method.
      call_cell = lambda: cell.call(input_, state, time)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()
      outputs.append(output)
    # Keras RNN cells only return state as list, even if it's a single tensor.
    if is_keras_rnn_cell and len(state) == 1:
      state = state[0]

    return (outputs, state)
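
# A hedged usage sketch (not from the original source; assumes TF 1.x graph
# mode). The `static_rnn` variant above forwards the loop index `time` to
# `cell.call`, so a stock cell needs a thin adapter that accepts and ignores
# the extra argument; `_TimeIgnoringCell` below is hypothetical.
import tensorflow as tf

class _TimeIgnoringCell(tf.nn.rnn_cell.RNNCell):
    """Wraps a cell and drops the extra `time` argument."""

    def __init__(self, cell):
        super(_TimeIgnoringCell, self).__init__()
        self._cell = cell

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def call(self, inputs, state, time=None):
        return self._cell(inputs, state)

def demo_static_rnn(batch_size=4, time_steps=5, depth=8, num_units=16):
    cell = _TimeIgnoringCell(tf.nn.rnn_cell.BasicRNNCell(num_units))
    inputs = tf.placeholder(tf.float32, [batch_size, time_steps, depth])
    # static_rnn expects a length-T list of [batch, depth] tensors.
    inputs_list = tf.unstack(inputs, axis=1)
    return static_rnn(cell, inputs_list, dtype=tf.float32)
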
def raw_rnn_for_beam_search(cell,
                            loop_fn,
                            parallel_iterations=None,
                            swap_memory=False,
                            scope=None):
    """A `raw_rnn` variant specialised for beam search: `loop_fn` additionally
    receives and returns per-beam log-probs and finished flags, and the loop
    records parent indices so beam paths can be reconstructed afterwards.
    Relies on module-level constants `eos_vocab_id` and `inf`.

    Returns:
      A tuple `(predicted_ids_ta, parent_ids_ta, penalties, final_log_probs)`.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)

    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not context.executing_eagerly():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state,
         emit_predicted_ids_structure,
         init_log_probs, init_beam_finished, _) = loop_fn(
             time, None, None, None,
             None)  # time, cell_output, cell_state, log_probs, beam_finished
        flat_input = nest.flatten(next_input)

        # Need a surrogate log_probs, beam_finished for the while_loop if none is available.
        log_probs = (init_log_probs if init_log_probs is not None else
                     constant_op.constant(0, dtype=dtypes.float32))
        beam_finished = (init_beam_finished if init_beam_finished is not None
                         else constant_op.constant(False, dtype=dtypes.bool))
        penalty_lengths = array_ops.zeros_like(log_probs, dtype=dtypes.float32)
        final_log_probs = array_ops.ones_like(log_probs, dtype=dtypes.float32)

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        # nest.assert_same_structure(initial_state, cell.state_size)
        # Note: the assertion above is skipped because the state is a tuple
        # whose number of elements depends on the beam width.
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_predicted_ids_structure is not None:
            flat_emit_structure = nest.flatten(emit_predicted_ids_structure)
            flat_emit_size = [
                emit.shape
                if emit.shape.is_fully_defined() else array_ops.shape(emit)
                for emit in flat_emit_structure
            ]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_predicted_ids_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_predicted_ids_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([
                    const_batch_size
                ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                clear_after_read=False,
                name="rnn_output_%d" % i)
            for i, (dtype_i,
                    size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        predicted_ids_ta = nest.pack_sequence_as(
            structure=emit_predicted_ids_structure, flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
        ]
        zero_emit = nest.pack_sequence_as(
            structure=emit_predicted_ids_structure,
            flat_sequence=flat_zero_emit)

        # parent_ids_in_beam_ta = tensor_array_ops.TensorArray(dtypes.int32, size=0,
        #     dynamic_size=True, clear_after_read=False).write(0, initial_parent_ids_value)
        parent_ids_in_beam_ta = tensor_array_ops.TensorArray(
            dtypes.int32, size=0, dynamic_size=True, clear_after_read=False)
        beam_width = array_ops.shape(log_probs)[-1]
        # [batch, beam] grid of beam indices; tiling (rather than stacking a
        # Python list) also works when batch_size is a dynamic tensor.
        index_for_finished_beam = array_ops.tile(
            array_ops.expand_dims(math_ops.range(beam_width), 0),
            [batch_size, 1])

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, _predicted_ids_ta,
                 state, log_probs, parent_index_ta, beam_finished,
                 penalty_lengths, _final_log_probs):
            """Internal while loop body for raw_rnn.

            Args:
              time: time scalar.
              elements_finished: batch-size vector.
              current_input: possibly nested tuple of input tensors.
              _predicted_ids_ta: possibly nested tuple of output TensorArrays.
              state: possibly nested tuple of state tensors.
              log_probs: per-beam log-probabilities carried between steps.
              parent_index_ta: TensorArray of each beam's parent index at the
                previous step (used when reconstructing the decoded path).
              _final_log_probs: table holding the final log-prob of each beam
                once it finishes (interpreted together with penalty_lengths).
            Returns:
              Tuple having the same size as Args but with updated values.
            """
            # ===========new code==================
            tuple_arr = [
                cell(_input, _state)
                for _input, _state in zip(current_input, state)
            ]
            # ====================================
            # (next_output, cell_state) = cell(current_input, state)
            # Note: the single-cell call above is replaced because beam search
            # runs the cell once per beam (see the list comprehension above).

            # =============new code================
            next_output = tuple(_output for _output, _ in tuple_arr)
            cell_state = tuple(_state for _, _state in tuple_arr)
            # =====================================
            nest.assert_same_structure(state, cell_state)
            # nest.assert_same_structure(cell.output_size, next_output)
            # Note: the assertion above is skipped because beam search changes
            # the output structure.

            next_time = time + 1
            (next_finished, next_input, next_state, predicted_ids,
             new_log_probs, new_beam_finished,
             parent_indexs) = loop_fn(time, next_output, cell_state, log_probs,
                                      beam_finished)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(_predicted_ids_ta, predicted_ids)
            # predicted_ids = logging_ops.Print(predicted_ids, [predicted_ids])
            # <eos> if finished at previous step
            predicted_ids = array_ops.where(
                beam_finished,
                array_ops.fill(array_ops.shape(predicted_ids), eos_vocab_id),
                predicted_ids)
            # predicted_ids = logging_ops.Print(predicted_ids, [predicted_ids[1]], message='ids[1] after where clause=')
            # should predict <eos> if finished
            # first update final_log_probs
            final_log_probs_not_updated = math_ops.equal(
                _final_log_probs, 1.)  # initial value is 1.0
            new_final_log_probs = array_ops.where(
                math_ops.logical_and(new_beam_finished,
                                     final_log_probs_not_updated),
                new_log_probs, _final_log_probs)  # stays unchanged once set
            # new_final_log_probs = logging_ops.Print(new_final_log_probs, [new_final_log_probs[10]], message='new_final_log_probs=')
            # new_log_probs = logging_ops.Print(new_log_probs, [new_log_probs[10]], message='new_log_probs=')
            # then set log_probs of finished beams to -inf so top_k never picks them
            new_log_probs = array_ops.where(
                new_beam_finished,
                array_ops.fill(array_ops.shape(log_probs), -inf),
                new_log_probs)
            # new_log_probs = logging_ops.Print(new_log_probs, [new_log_probs])
            penalty_lengths = array_ops.where(
                new_beam_finished, penalty_lengths,
                penalty_lengths + 1)  # +1 if NOT finished

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i,
                                               cand_i)

                return nest.map_structure(copy_fn, current, candidate)

            predicted_ids = _copy_some_through(zero_emit, predicted_ids)
            # predicted_ids = logging_ops.Print(predicted_ids, [predicted_ids[1]], message='ids[1] after copy_some_through')
            next_state = _copy_some_through(state, next_state)

            _predicted_ids_ta = nest.map_structure(
                lambda ta, emit: ta.write(time, emit), _predicted_ids_ta,
                predicted_ids)

            # parent_indexs = control_flow_ops.cond(math_ops.equal(time, 0),
            #                 lambda: parent_indexs,  # pass it if time=0 (all filled with -1)
            #                 lambda: array_ops.where(new_beam_finished,
            #                         parent_index_ta.read(time - 1), parent_indexs))  # prev_ids if beam is finish)

            parent_indexs = control_flow_ops.cond(
                math_ops.equal(time, 0),
                lambda:
                parent_indexs,  # pass it if time=0 (all filled with -1)
                lambda: array_ops.where(
                    beam_finished,
                    index_for_finished_beam,  # true index 0,1,2,...,beam
                    parent_indexs))

            # parent_indexs = logging_ops.Print(parent_indexs, [parent_indexs[1]], message='parent[1]=')
            parent_index_ta = parent_index_ta.write(time, parent_indexs)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            return (next_time, elements_finished, next_input,
                    _predicted_ids_ta, next_state, new_log_probs,
                    parent_index_ta, new_beam_finished, penalty_lengths,
                    new_final_log_probs)

        returned = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                time, elements_finished, next_input, predicted_ids_ta, state,
                log_probs, parent_ids_in_beam_ta, beam_finished,
                penalty_lengths, final_log_probs
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        (_, _, _, predicted_ids_ta, _, _, parent_ids_ta, _, penalties,
         final_log_probs) = returned

        # Some elements of final_log_probs are still 1.0: their beams stopped
        # by reaching the maximum sentence length rather than emitting <eos>.
        # Turn them into -inf so they are never chosen after normalization.
        final_log_probs_still_not_updated = math_ops.equal(final_log_probs, 1.)
        final_log_probs = array_ops.where(
            final_log_probs_still_not_updated,
            array_ops.fill(array_ops.shape(final_log_probs), -inf),
            final_log_probs)

        return predicted_ids_ta, parent_ids_ta, penalties, final_log_probs
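
# A hedged post-processing sketch (not from the original source; NumPy only).
# Given `predicted_ids` and `parent_ids` obtained by stacking the returned
# TensorArrays (shape [time, batch, beam]) and `final_log_probs` of shape
# [batch, beam], the best hypothesis per batch row can be recovered by walking
# the parent pointers backwards from the last step:
import numpy as np

def backtrack_best_beam(predicted_ids, parent_ids, final_log_probs):
    time_steps, batch_size, _ = predicted_ids.shape
    beam = np.argmax(final_log_probs, axis=-1)    # best final beam, [batch]
    tokens = np.zeros((time_steps, batch_size), dtype=predicted_ids.dtype)
    rows = np.arange(batch_size)
    for t in range(time_steps - 1, -1, -1):
        tokens[t] = predicted_ids[t, rows, beam]  # token emitted at step t
        beam = parent_ids[t, rows, beam]          # hop to the parent beam
    return tokens.T                               # [batch, time]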