Example #1
  def __init__(self, cell, helper, initial_state, output_layer=None):
    """Initialize BasicDecoder.

    Args:
      cell: An `RNNCell` instance.
      helper: A `Helper` instance.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        The initial state of the RNNCell.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
        to storing the result or sampling.

    Raises:
      TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
    """
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
    if not isinstance(helper, helper_py.Helper):
      raise TypeError("helper must be a Helper, received: %s" % type(helper))
    if (output_layer is not None
        and not isinstance(output_layer, layers_base.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._helper = helper
    self._initial_state = initial_state
    self._output_layer = output_layer
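
For context, a minimal usage sketch (not part of the snippet above) showing how a decoder built this way is typically driven during training with the TF 1.x contrib APIs; the placeholder shapes, unit counts, and vocabulary size are assumptions made for illustration.

import tensorflow as tf

embed_dim, vocab_size = 128, 10000
decoder_inputs = tf.placeholder(tf.float32, [None, None, embed_dim])  # [batch, time, embed]
target_lengths = tf.placeholder(tf.int32, [None])                     # [batch]
batch_size = tf.shape(decoder_inputs)[0]

cell = tf.nn.rnn_cell.LSTMCell(num_units=256)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, target_lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell, helper,
    initial_state=cell.zero_state(batch_size, tf.float32),
    output_layer=tf.layers.Dense(vocab_size, use_bias=False))
outputs, final_state, final_lengths = tf.contrib.seq2seq.dynamic_decode(decoder)
logits = outputs.rnn_output  # [batch, time, vocab_size]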
Example #2
 def __init__(self, cell, helper, initial_state, output_layer=None):
   """Initialize CustomDecoder.
   Args:
     cell: An `RNNCell` instance.
     helper: A `Helper` instance.
     initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
       The initial state of the RNNCell.
     output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
       `tf.layers.Dense`. Optional layer to apply to the RNN output prior
       to storing the result or sampling.
   Raises:
     TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
   """
   if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
     raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
   if not isinstance(helper, helper_py.Helper):
     raise TypeError("helper must be a Helper, received: %s" % type(helper))
   if (output_layer is not None
       and not isinstance(output_layer, layers_base.Layer)):
     raise TypeError(
         "output_layer must be a Layer, received: %s" % type(output_layer))
   self._cell = cell
   self._helper = helper
   self._initial_state = initial_state
   self._output_layer = output_layer
Example #3
    def __init__(self, cell, helper, initial_state, output_layer=None):
        '''
        Custom decoder; see https://blog.csdn.net/thriving_fcl/article/details/74165062
        :param cell: `RNNCell` instance
        :param helper: `Helper` instance
        :param initial_state: The initial state of the RNNCell -> encoder output
        :param output_layer: tf.layers.Layer -> tf.layers.Dense
        '''
        if parse_version(tf.__version__) >= parse_version('1.10'):
            rnn_cell_impl.assert_like_rnncell(type(cell), cell)
        else:
            if not rnn_cell_impl._like_rnncell(cell):
                raise TypeError('cell must be an RNNCell, received: %s' %
                                type(cell))

        if not isinstance(helper, helper_py.Helper):
            raise TypeError('helper must be a Helper, received: %s' %
                            type(helper))
        if output_layer is not None and not isinstance(output_layer,
                                                       layers_base.Layer):
            raise TypeError('output_layer must be a Layer, received: %s' %
                            type(output_layer))
        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        self._output_layer = output_layer
Example #4
    def __init__(self, cell, helper, initial_state, output_layer=None):
        """Initialize BasicVectorDecoder.

        Args:
            cell: An `RNNCell` instance.
            helper: A `Helper` instance.
            initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
                The initial state of the RNNCell.
            output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`.
                If not provided, a 2-unit dense "stop_predictor" layer is used.

        Raises:
            TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
        """
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
        if not isinstance(helper, helper_py.Helper):
            raise TypeError("helper must be a Helper, received: %s" % type(helper))
        if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError(
                "output_layer must be a Layer, received: %s" % type(output_layer))
        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        if output_layer is None:
            self._output_layer = layer_core.Dense(2, use_bias=True,
                                                  name="stop_predictor")
        else:
            self._output_layer = output_layer
Example #5
  def __init__(self, cell, input_size=None, state_is_tuple=True, reuse=None,
               emb_M3=None, emb_M4k=None):
    super(NLabelNoAttentionCellWrapper, self).__init__(_reuse=reuse)
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError("The parameter cell is not RNNCell.")
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
      raise ValueError("Cell returns tuple of states, but the flag "
                       "state_is_tuple is not set. State size is: %s"
                       % str(cell.state_size))

    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)

    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._input_size = input_size
    self._reuse = reuse
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None
    self.emb_M3 = emb_M3
    self.emb_M4k = emb_M4k
    self.config = Config()
    self._output_size = cell.output_size
Example #6
 def _check_inputs(self, cell, state_is_tuple):
     if not rnn_cell_impl._like_rnncell(cell):  
         raise TypeError("The parameter cell is not RNNCell.")
     if Config.ATTN_TYPE == Config.attn_temporal and Config.ATTN_TEMPORAL_WINDOW <= 0:
         raise ValueError("Config.ATTN_TEMPORAL_WINDOW should be greater than zero, got %s" % str(Config.ATTN_TEMPORAL_WINDOW))
     if not state_is_tuple:
         raise ValueError("Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.")
Example #7
    def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
               input_size=None, state_is_tuple=True, reuse=None):
        """Create a cell with attention.
        Args:
          cell: an RNNCell, an attention is added to it.
          attn_length: integer, the size of an attention window.
          attn_size: integer, the size of an attention vector. Equal to
              cell.output_size by default.
          attn_vec_size: integer, the number of convolutional features calculated
              on attention state and a size of the hidden layer built from
              base cell state. Equal to attn_size by default.
          input_size: integer, the size of a hidden linear layer,
              built from inputs and attention. Derived from the input tensor
              by default.
          state_is_tuple: If True (default), accepted and returned states are
            n-tuples, where `n = len(cells)`.  If False, the states are all
            concatenated along the column axis.
          reuse: (optional) Python boolean describing whether to reuse variables
            in an existing scope.  If not `True`, and the existing scope already has
            the given variables, an error is raised.
        Raises:
          TypeError: if cell is not an RNNCell.
          ValueError: if cell returns a state tuple but the flag
              `state_is_tuple` is `False` or if attn_length is zero or less.
        """
        super(AttentionCellWrapper, self).__init__(_reuse=reuse)

        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("The parameter cell is not RNNCell.")
        if nest.is_sequence(cell.state_size) and not state_is_tuple:
            raise ValueError("Cell returns tuple of states, but the flag "
                           "state_is_tuple is not set. State size is: %s"
                           % str(cell.state_size))
        if attn_length <= 0:
            raise ValueError("attn_length should be greater than zero, got %s"
                           % str(attn_length))
        if not state_is_tuple:
            logging.warn(
              "%s: Using a concatenated state is slower and will soon be "
              "deprecated.  Use state_is_tuple=True.", self)
        if attn_size is None:
            attn_size = cell.output_size
        if attn_vec_size is None:
            attn_vec_size = attn_size
        self._state_is_tuple = state_is_tuple
        self._cell = cell
        self._attn_vec_size = attn_vec_size
        if input_size:
            self._input_size = input_size - 1  # discount phase
        else:
            self._input_size = input_size
        self._attn_size = attn_size
        self._attn_length = attn_length
        self._reuse = reuse
        self._linear1 = None
        self._linear2 = None
        self._linear3 = None
        self.phase = None
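
A rough usage sketch, assuming the stock `tf.contrib.rnn.AttentionCellWrapper` (which has the same constructor signature as the variant above); shapes and unit counts are illustrative only.

import tensorflow as tf

base_cell = tf.nn.rnn_cell.GRUCell(128)
attn_cell = tf.contrib.rnn.AttentionCellWrapper(
    base_cell, attn_length=16, state_is_tuple=True)
inputs = tf.placeholder(tf.float32, [None, None, 64])  # [batch, time, features]
outputs, state = tf.nn.dynamic_rnn(attn_cell, inputs, dtype=tf.float32)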
Example #8
    def __init__(self, cell, output_size, sentence_index, activation=None):
        super(CopyWrapper, self).__init__()
        if not _like_rnncell(cell_name="copy_cell", cell=cell):
            raise TypeError('The parameter cell is not RNNCell.')

        self._cell = cell
        self._output_size = output_size
        self._sentence_index = sentence_index
        self._activation = activation
        self._linear = None
Example #9
    def __init__(self, cell, helper, initial_state, output_layer=None):
        """Initialize SonnetBasicDecoder.

    Args:
      cell: An `RNNCell` instance.
      helper: A `Helper` instance.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        The initial state of the RNNCell.
      output_layer: (Optional) An instance of `tf.layers.Layer`, e.g.,
        `tf.layers.Dense` or snt.AbstractModule, e.g. snt.Linear.
        Optional layer to apply to the RNN output prior to storing the result or
        sampling.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`, `helper`
        is not an instance of `Helper`, or `output_layer` is not an instance
        of `tf.layers.Layer`.
    """

        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, received: %s" %
                            type(cell))
        if not isinstance(helper, helper_py.Helper):
            raise TypeError("helper must be a Helper, received: %s" %
                            type(helper))
        if output_layer is not None:
            if isinstance(output_layer, snt.AbstractModule):
                output_layer_type = 'AbstractModule'
            elif isinstance(output_layer, layers_base._Layer):  # pylint: disable=protected-access
                output_layer_type = 'Layer'
            else:
                raise TypeError("output_layer must be a Layer or a "
                                "Sonnet AbstractModule, received: %s" %
                                type(output_layer))
        else:
            output_layer_type = None
        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        self._output_layer = output_layer
        self._output_layer_type = output_layer_type
Example #10
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 attention_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None):
        """Construct the `AttentionWrapper`.

    Args:
      cell: An instance of `RNNCell`.
      attention_mechanism: An instance of `AttentionMechanism`.
      attention_layer_size: Python integer, the depth of the attention (output)
        layer. If None (default), use the context as attention at each time
        step. Otherwise, feed the context and cell output into the attention
        layer to generate attention at each time step.
      alignment_history: Python boolean, whether to store alignment history
        from all time steps in the final output state (currently stored as a
        time major `TensorArray` on which you must call `stack()`).
      cell_input_fn: (optional) A `callable`.  The default is:
        `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
      attention_input_fn: (optional) A `callable` that builds the query passed to
        the attention mechanism from the cell inputs and state.  The default is:
        `lambda _, state: state`.
      output_attention: Python bool.  If `True` (default), the output at each
        time step is the attention value.  This is the behavior of Luong-style
        attention mechanisms.  If `False`, the output at each time step is
        the output of `cell`.  This is the behavior of Bahdanau-style
        attention mechanisms.  In both cases, the `attention` tensor is
        propagated to the next time step via the state and is used there.
        This flag only controls whether the attention mechanism is propagated
        up to the next cell in an RNN stack or to the top RNN output.
      initial_cell_state: The initial state value to use for the cell when
        the user calls `zero_state()`.  Note that if this value is provided
        now, and the user uses a `batch_size` argument of `zero_state` which
        does not match the batch size of `initial_cell_state`, proper
        behavior is not guaranteed.
      name: Name to use when creating ops.
    """
        super(AttentionWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, saw type: %s" %
                            type(cell).__name__)
        if not isinstance(attention_mechanism, AttentionMechanism):
            raise TypeError(
                "attention_mechanism must be a AttentionMechanism, saw type: %s"
                % type(attention_mechanism).__name__)

        # -- what gets input to the core RNN cell we're wrapping around
        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        # Added to allow different inputs to the attention mechanism:
        # what the attention unit gets as the query.
        if attention_input_fn is None:
            attention_input_fn = (lambda _, state: state)
        else:
            if not callable(attention_input_fn):
                raise TypeError(
                    "attention_input_fn must be callable, saw type: %s" %
                    type(attention_input_fn).__name__)


        if attention_layer_size is not None:
            self._attention_layer = layers_core.Dense(attention_layer_size,
                                                      name="attention_layer",
                                                      use_bias=False)
            self._attention_size = attention_layer_size
        else:
            self._attention_layer = None
            self._attention_size = attention_mechanism.values.get_shape()[-1].value

        self._cell = cell
        self._attention_mechanism = attention_mechanism
        self._cell_input_fn = cell_input_fn
        self._attention_input_fn = attention_input_fn

        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with ops.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttentionWrapper %s: " % self._base_name
                    + "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.  Are you using "
                    "the BeamSearchDecoder?  You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with ops.control_dependencies([
                        check_ops.assert_equal(
                            state_batch_size,
                            self._attention_mechanism.batch_size,
                            message=error_message)
                ]):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: array_ops.identity(
                            s, name="check_initial_cell_state"),
                        initial_cell_state)
Example #11
    def __init__(self,
                 cell,
                 memory,
                 memory_sequence_length,
                 output_layer,
                 max_oovs,
                 batch_size,
                 memory_full_vocab,
                 first_lv_sim_func,
                 second_lv_sim_func,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=False,
                 output_generation_distribution=False,
                 output_copy_distribution=False,
                 output_combined_distribution=True,
                 initial_cell_state=None,
                 unk_id=None,
                 name=None):

        super(HierarchicalAttnPointerWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, saw type: %s" %
                            type(cell).__name__)

        self._is_multi = False

        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(attention_layer_size if isinstance(
                attention_layer_size, (list,
                                       tuple)) else (attention_layer_size, ))
            if len(attention_layer_sizes) != 1:
                raise ValueError(
                    "If provided, attention_layer_size must contain exactly one "
                    "integer per attention_mechanism, saw: %d vs 1" %
                    (len(attention_layer_sizes)))

            self._attention_layers = tuple(
                layers_core.Dense(attention_layer_size,
                                  name="attention_layer",
                                  use_bias=False)
                for attention_layer_size in attention_layer_sizes)
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            self._attention_layers = None
            self._attention_layer_size = memory.get_shape()[-1].value

        self._cell = cell
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._output_generation_distribution = output_generation_distribution
        self._output_copy_distribution = output_copy_distribution
        self._output_combined_distribution = output_combined_distribution
        self._unk_id = unk_id
        self._alignment_history = alignment_history
        self._output_layer = output_layer
        self._max_oovs = max_oovs
        self._batch_size = batch_size

        [self._b, self._k, _, h] = memory.get_shape().as_list()
        #self._k = tf.shape(memory)[1].value
        #self._b = tf.shape(memory)[0].value
        #h = tf.shape(memory)[-1].value

        b = self._b
        k = self._k

        mem_reshaped = tf.reshape(memory, [b * k, -1, h])
        print(mem_reshaped.get_shape().as_list())
        mem_mask_reshaped = tf.reshape(memory_sequence_length, [-1])

        self._memory = tf.reshape(
            _prepare_memory(mem_reshaped, mem_mask_reshaped, False),
            [b, k, -1, h])
        self._memory_full_vocab = memory_full_vocab

        self._attention_mechanisms = [None]  # placeholder

        with tf.variable_scope("first_lv_attn"):
            self._first_lv_sim_func = first_lv_sim_func

        with tf.variable_scope("second_lv_attn"):
            self._second_lv_sim_func = second_lv_sim_func

        if self._output_combined_distribution or \
                self._output_generation_distribution or \
                self._output_copy_distribution or \
                self._output_attention:
            assert self._output_combined_distribution ^\
                self._output_generation_distribution ^\
                self._output_copy_distribution ^\
                self._output_attention, "Can only output one type!"

        if self._output_combined_distribution or self._output_copy_distribution:
            assert self._unk_id is not None

        with ops.name_scope(name, "AttnPointerWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttnPointerWrapper %s: " %
                    self._base_name +
                    "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.  Are you using "
                    "the BeamSearchDecoder?  You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: array_ops.identity(
                            s, name="check_initial_cell_state"),
                        initial_cell_state)
Example #12
    def __init__(self,
                 cell,
                 seq_len,
                 len_embeddings,
                 cell_input_fn=None,
                 initial_cell_state=None,
                 name=None):
        """Construct the `AttentionWrapper`.

        Args:
          cell: An instance of `RNNCell`.
          alignment_inputs: inputs
          cell_input_fn: (optional) A `callable`.  The default is:
            `lambda inputs, alignment_input: array_ops.concat([inputs, alignment_input], -1)`.
          initial_cell_state: The initial state value to use for the cell when
            the user calls `zero_state()`.  Note that if this value is provided
            now, and the user uses a `batch_size` argument of `zero_state` which
            does not match the batch size of `initial_cell_state`, proper
            behavior is not guaranteed.
          name: Name to use when creating ops.

        Raises:
          TypeError: `attention_layer_size` is not None and (`attention_mechanism`
            is a list but `attention_layer_size` is not; or vice versa).
          ValueError: if `attention_layer_size` is not None, `attention_mechanism`
            is a list, and its length does not match that of `attention_layer_size`.
        """
        super(LenControlWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, saw type: %s" %
                            type(cell).__name__)

        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, len_embedding: array_ops.concat(
                [inputs, len_embedding], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        self._cell = cell
        self._seq_len = seq_len
        self._len_embeddings = len_embeddings
        self._cell_input_fn = cell_input_fn
        with ops.name_scope(name, "LenControlWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing LenControlWrapper %s: " %
                    self._base_name +
                    "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.  Are you using "
                    "the BeamSearchDecoder?  You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: array_ops.identity(
                            s, name="check_initial_cell_state"),
                        initial_cell_state)
Example #13
    def __init__(self,
                 cell,
                 attention_mechanism,
                 output_layer,
                 max_oovs,
                 batch_size,
                 memory_full_vocab,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=False,
                 output_generation_distribution=False,
                 output_copy_distribution=False,
                 output_combined_distribution=True,
                 initial_cell_state=None,
                 unk_id=None,
                 name=None):

        super(AttnPointerWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError(
                    "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        "attention_mechanism must contain only instances of "
                        "AttentionMechanism, saw type: %s"
                        % type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError(
                        "attention_mechanism must be an AttentionMechanism or list of "
                        "multiple AttentionMechanism instances, saw type: %s"
                        % type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism,)

        if cell_input_fn is None:
            cell_input_fn = (
                    lambda inputs, attention: array_ops.concat([inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                        "cell_input_fn must be callable, saw type: %s"
                        % type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(
                    attention_layer_size
                    if isinstance(attention_layer_size, (list, tuple))
                    else (attention_layer_size,))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError(
                        "If provided, attention_layer_size must contain exactly one "
                        "integer per attention_mechanism, saw: %d vs %d"
                        % (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(
                    layers_core.Dense(
                            attention_layer_size, name="attention_layer", use_bias=False)
                    for attention_layer_size in attention_layer_sizes)
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            self._attention_layers = None
            self._attention_layer_size = sum(
                    attention_mechanism.values.get_shape()[-1].value
                    for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._output_generation_distribution = output_generation_distribution
        self._output_copy_distribution = output_copy_distribution
        self._output_combined_distribution = output_combined_distribution
        self._unk_id = unk_id
        self._alignment_history = alignment_history
        self._output_layer = output_layer
        self._max_oovs = max_oovs
        self._batch_size = batch_size

        if memory_full_vocab is not None:
            self._memory_full_vocab = tuple(
                    memory_full_vocab
                    if isinstance(memory_full_vocab, (list, tuple))
                    else (memory_full_vocab, ))

            if len(self._memory_full_vocab) != len(attention_mechanisms):
                raise ValueError("memory full vocab must be same size as"
                        "attention mechanisms, saw %d vs %d" 
                        % (len(memory_full_vocab), len(attention_mechanisms)))

        if self._output_combined_distribution or \
                self._output_generation_distribution or \
                self._output_copy_distribution or \
                self._output_attention:
            assert self._output_combined_distribution ^\
                self._output_generation_distribution ^\
                self._output_copy_distribution ^\
                self._output_attention, "Can only output one type!"

        if self._output_combined_distribution or self._output_copy_distribution:
            assert self._unk_id is not None

        with ops.name_scope(name, "AttnPointerWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (
                        final_state_tensor.shape[0].value
                        or array_ops.shape(final_state_tensor)[0])
                error_message = (
                        "When constructing AttnPointerWrapper %s: " % self._base_name +
                        "Non-matching batch sizes between the memory "
                        "(encoder output) and initial_cell_state.  Are you using "
                        "the BeamSearchDecoder?  You may need to tile your initial state "
                        "via the tf.contrib.seq2seq.tile_batch function with argument "
                        "multiple=beam_width.")
                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size, error_message)):
                    self._initial_cell_state = nest.map_structure(
                            lambda s: array_ops.identity(s, name="check_initial_cell_state"),
                            initial_cell_state)
Example #14
File: cells.py  Project: carusyte/tflab
    def __init__(self,
                 cell,
                 input_keep_prob=1.0,
                 output_keep_prob=1.0,
                 state_keep_prob=1.0,
                 variational_recurrent=False,
                 input_size=None,
                 dtype=None,
                 seed=None,
                 dropout_state_filter_visitor=None):
        """Create a cell with added input, state, and/or output dropout.

        If `variational_recurrent` is set to `True` (**NOT** the default behavior),
        then the same dropout mask is applied at every step, as described in:

        Y. Gal, Z Ghahramani.  "A Theoretically Grounded Application of Dropout in
        Recurrent Neural Networks".  https://arxiv.org/abs/1512.05287

        Otherwise a different dropout mask is applied at every time step.

        Note, by default (unless a custom `dropout_state_filter` is provided),
        the memory state (`c` component of any `LSTMStateTuple`) passing through
        a `DropoutWrapper` is never modified.  This behavior is described in the
        above article.

        Args:
          cell: an RNNCell; input, state, and/or output dropout is added to it.
          input_keep_prob: unit Tensor or float between 0 and 1, input keep
            probability; if it is constant and 1, no input dropout will be added.
          output_keep_prob: unit Tensor or float between 0 and 1, output keep
            probability; if it is constant and 1, no output dropout will be added.
          state_keep_prob: unit Tensor or float between 0 and 1, state keep
            probability; if it is constant and 1, no state dropout will be added.
            State dropout is performed on the outgoing states of the cell.
            **Note** the state components to which dropout is applied when
            `state_keep_prob` is in `(0, 1)` are also determined by
            the argument `dropout_state_filter_visitor` (e.g. by default dropout
            is never applied to the `c` component of an `LSTMStateTuple`).
          variational_recurrent: Python bool.  If `True`, then the same
            dropout pattern is applied across all time steps per run call.
            If this parameter is set, `input_size` **must** be provided.
          input_size: (optional) (possibly nested tuple of) `TensorShape` objects
            containing the depth(s) of the input tensors expected to be passed in to
            the `DropoutWrapper`.  Required and used **iff**
             `variational_recurrent = True` and `input_keep_prob < 1`.
          dtype: (optional) The `dtype` of the input, state, and output tensors.
            Required and used **iff** `variational_recurrent = True`.
          seed: (optional) integer, the randomness seed.
          dropout_state_filter_visitor: (optional), default: (see below).  Function
            that takes any hierarchical level of the state and returns
            a scalar or depth=1 structure of Python booleans describing
            which terms in the state should be dropped out.  In addition, if the
            function returns `True`, dropout is applied across this sublevel.  If
            the function returns `False`, dropout is not applied across this entire
            sublevel.
            Default behavior: perform dropout on all terms except the memory (`c`)
            state of `LSTMCellState` objects, and don't try to apply dropout to
            `TensorArray` objects:
            ```
            def dropout_state_filter_visitor(s):
              if isinstance(s, LSTMCellState):
                # Never perform dropout on the c state.
                return LSTMCellState(c=False, h=True)
              elif isinstance(s, TensorArray):
                return False
              return True
            ```

        Raises:
          TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided
            but not `callable`.
          ValueError: if any of the keep_probs are not between 0 and 1.
        """
        if not rnn_cell_impl._like_rnncell(cell):
            raise TypeError("The parameter cell is not a RNNCell.")
        if (dropout_state_filter_visitor is not None
                and not callable(dropout_state_filter_visitor)):
            raise TypeError("dropout_state_filter_visitor must be callable")
        self._dropout_state_filter = (
            dropout_state_filter_visitor
            or rnn_cell_impl._default_dropout_state_filter_visitor)
        with ops.name_scope("DropoutWrapperInit"):

            def tensor_and_const_value(v):
                tensor_value = ops.convert_to_tensor(v)
                const_value = tensor_util.constant_value(tensor_value)
                return (tensor_value, const_value)

            for prob, attr in [(input_keep_prob, "input_keep_prob"),
                               (state_keep_prob, "state_keep_prob"),
                               (output_keep_prob, "output_keep_prob")]:
                tensor_prob, const_prob = tensor_and_const_value(prob)
                if const_prob is not None:
                    if const_prob < 0 or const_prob > 1:
                        raise ValueError(
                            "Parameter %s must be between 0 and 1: %d" %
                            (attr, const_prob))
                    setattr(self, "_%s" % attr, float(const_prob))
                else:
                    setattr(self, "_%s" % attr, tensor_prob)

        # Set cell, variational_recurrent, seed before running the code below
        self._cell = cell
        self._variational_recurrent = variational_recurrent
        self._seed = seed

        self._recurrent_input_noise = None
        self._recurrent_state_noise = None
        self._recurrent_output_noise = None

        if variational_recurrent:
            if dtype is None:
                raise ValueError(
                    "When variational_recurrent=True, dtype must be provided")

            def convert_to_batch_shape(s):
                # Prepend a 1 for the batch dimension; for recurrent
                # variational dropout we use the same dropout mask for all
                # batch elements.
                return array_ops.concat(
                    ([1], tensor_shape.TensorShape(s).as_list()), 0)

            def batch_noise(s, inner_seed):
                shape = convert_to_batch_shape(s)
                return random_ops.random_uniform(shape,
                                                 seed=inner_seed,
                                                 dtype=dtype)

            if (not isinstance(self._input_keep_prob, numbers.Real)
                    or self._input_keep_prob < 1.0):
                if input_size is None:
                    raise ValueError(
                        "When variational_recurrent=True and input_keep_prob < 1.0 or "
                        "is unknown, input_size must be provided")
                self._recurrent_input_noise = rnn_cell_impl._enumerated_map_structure_up_to(
                    input_size, lambda i, s: batch_noise(
                        s, inner_seed=self._gen_seed("input", i)), input_size)
            self._recurrent_state_noise = rnn_cell_impl._enumerated_map_structure_up_to(
                cell.state_size, lambda i, s: batch_noise(
                    s, inner_seed=self._gen_seed("state", i)), cell.state_size)
            self._recurrent_output_noise = rnn_cell_impl._enumerated_map_structure_up_to(
                cell.output_size, lambda i, s: batch_noise(
                    s, inner_seed=self._gen_seed("output", i)),
                cell.output_size)
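
A brief sketch of the two dropout modes described in the docstring above, using the stock `tf.nn.rnn_cell.DropoutWrapper` (which this copy mirrors); the unit counts and the input depth of 128 are assumptions for illustration.

import tensorflow as tf

lstm = tf.nn.rnn_cell.LSTMCell(256)

# Per-step dropout: a fresh mask is sampled at every time step.
dropped = tf.nn.rnn_cell.DropoutWrapper(lstm, input_keep_prob=0.8, output_keep_prob=0.8)

# Variational dropout: one mask per sequence, so input_size and dtype are required.
variational = tf.nn.rnn_cell.DropoutWrapper(
    lstm,
    input_keep_prob=0.8,
    output_keep_prob=0.8,
    state_keep_prob=0.9,
    variational_recurrent=True,
    input_size=tf.TensorShape([128]),  # depth of the inputs fed to the wrapper (assumed)
    dtype=tf.float32)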
Example #15
def raw_rnn(cell,
            loop_fn,
            parallel_iterations=None,
            swap_memory=False,
            scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the time axis)
    in addition to the outputs at each timestep and the final state

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )
    """
    if not _like_rnncell(cell):
        raise TypeError("cell must be an instance of RNNCell")
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not context.executing_eagerly():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None else
                      constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [
                emit.shape
                if emit.shape.is_fully_defined() else array_ops.shape(emit)
                for emit in flat_emit_structure
            ]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [
            s.shape if s.shape.is_fully_defined() else array_ops.shape(s)
            for s in flat_state
        ]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([
                    const_batch_size
                ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i)
            for i, (dtype_i,
                    size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                        flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
        ]

        zero_emit = nest.pack_sequence_as(structure=emit_structure,
                                          flat_sequence=flat_zero_emit)

        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([
                    const_batch_size
                ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i)
            for i, (
                dtype_i,
                size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state,
                                         flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta,
                 state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i,
                                               cand_i)

                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                                         emit_ta, emit_output)
            state_ta = nest.map_structure(
                lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                time, elements_finished, next_input, state_ta, emit_ta, state,
                loop_state
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        flat_states = nest.flatten(state_ta)
        flat_states = [
            array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states
        ]
        states = nest.pack_sequence_as(structure=state_ta,
                                       flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [
            array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs
        ]
        outputs = nest.pack_sequence_as(structure=emit_ta,
                                        flat_sequence=flat_outputs)

        return (states, outputs, final_state)
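
A minimal `loop_fn` sketch for driving the adapted `raw_rnn` above, following the standard TensorFlow `raw_rnn` contract; the 64-dim input size and placeholder shapes are assumptions for illustration.

import tensorflow as tf

input_dim = 64
cell = tf.nn.rnn_cell.LSTMCell(128)
inputs = tf.placeholder(tf.float32, [None, None, input_dim])  # [batch, time, features]
sequence_length = tf.placeholder(tf.int32, [None])
batch_size = tf.shape(inputs)[0]
inputs_ta = tf.TensorArray(dtype=tf.float32, size=tf.shape(inputs)[1])
inputs_ta = inputs_ta.unstack(tf.transpose(inputs, [1, 0, 2]))  # time-major slices

def loop_fn(time, cell_output, cell_state, loop_state):
    finished = time >= sequence_length
    # First call (cell_output is None): supply the initial state.
    next_state = cell.zero_state(batch_size, tf.float32) if cell_output is None else cell_state
    # Read the next input, or zeros once every sequence has finished.
    next_input = tf.cond(
        tf.reduce_all(finished),
        lambda: tf.zeros([batch_size, input_dim], tf.float32),
        lambda: inputs_ta.read(time))
    return finished, next_input, next_state, cell_output, loop_state

states, outputs, final_state = raw_rnn(cell, loop_fn)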
Example #16
  def __init__(self,
               cell,
               attention_mechanism,
               attention_layer_size=None,
               alignment_history=False,
               cell_input_fn=None,
               output_attention=True,
               initial_cell_state=None,
               name=None):
    """Construct the `AttentionWrapper`.

    Args:
      cell: An instance of `RNNCell`.
      attention_mechanism: An instance of `AttentionMechanism`.
      attention_layer_size: Python integer, the depth of the attention (output)
        layer. If None (default), use the context as attention at each time
        step. Otherwise, feed the context and cell output into the attention
        layer to generate attention at each time step.
      alignment_history: Python boolean, whether to store alignment history
        from all time steps in the final output state (currently stored as a
        time major `TensorArray` on which you must call `stack()`).
      cell_input_fn: (optional) A `callable`.  The default is:
        `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
      output_attention: Python bool.  If `True` (default), the output at each
        time step is the attention value.  This is the behavior of Luong-style
        attention mechanisms.  If `False`, the output at each time step is
        the output of `cell`.  This is the behavior of Bahdanau-style
        attention mechanisms.  In both cases, the `attention` tensor is
        propagated to the next time step via the state and is used there.
        This flag only controls whether the attention mechanism is propagated
        up to the next cell in an RNN stack or to the top RNN output.
      initial_cell_state: The initial state value to use for the cell when
        the user calls `zero_state()`.  Note that if this value is provided
        now, and the user uses a `batch_size` argument of `zero_state` which
        does not match the batch size of `initial_cell_state`, proper
        behavior is not guaranteed.
      name: Name to use when creating ops.
    """
    super(AttentionWrapper, self).__init__(name=name)
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError(
          "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
    if not isinstance(attention_mechanism, AttentionMechanism):
      raise TypeError(
          "attention_mechanism must be a AttentionMechanism, saw type: %s"
          % type(attention_mechanism).__name__)
    if cell_input_fn is None:
      cell_input_fn = (
          lambda inputs, attention: array_ops.concat([inputs, attention], -1))
    else:
      if not callable(cell_input_fn):
        raise TypeError(
            "cell_input_fn must be callable, saw type: %s"
            % type(cell_input_fn).__name__)

    if attention_layer_size is not None:
      self._attention_layer = layers_core.Dense(
          attention_layer_size, name="attention_layer", use_bias=False)
      self._attention_layer_size = attention_layer_size
    else:
      self._attention_layer = None
      self._attention_layer_size = attention_mechanism.values.get_shape()[
          -1].value

    self._cell = cell
    self._attention_mechanism = attention_mechanism
    self._cell_input_fn = cell_input_fn
    self._output_attention = output_attention
    self._alignment_history = alignment_history
    with ops.name_scope(name, "AttentionWrapperInit"):
      if initial_cell_state is None:
        self._initial_cell_state = None
      else:
        final_state_tensor = nest.flatten(initial_cell_state)[-1]
        state_batch_size = (
            final_state_tensor.shape[0].value
            or array_ops.shape(final_state_tensor)[0])
        error_message = (
            "When constructing AttentionWrapper %s: " % self._base_name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and initial_cell_state.  Are you using "
            "the BeamSearchDecoder?  You may need to tile your initial state "
            "via the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with ops.control_dependencies(
            [check_ops.assert_equal(state_batch_size,
                                    self._attention_mechanism.batch_size,
                                    message=error_message)]):
          self._initial_cell_state = nest.map_structure(
              lambda s: array_ops.identity(s, name="check_initial_cell_state"),
              initial_cell_state)
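
A compact usage sketch, assuming the stock `tf.contrib.seq2seq.LuongAttention` and the standard `AttentionWrapper` constructor shown above; dimensions are illustrative only.

import tensorflow as tf

encoder_outputs = tf.placeholder(tf.float32, [None, None, 256])  # [batch, src_time, depth]
source_lengths = tf.placeholder(tf.int32, [None])

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=256,
    memory=encoder_outputs,
    memory_sequence_length=source_lengths)
decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.LSTMCell(256),
    attention_mechanism,
    attention_layer_size=256,
    alignment_history=False)
initial_state = decoder_cell.zero_state(tf.shape(encoder_outputs)[0], tf.float32)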
Example #17
  def __init__(self,
               cell,
               embedding,
               start_tokens,
               end_token,
               initial_state,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0):
    """Initialize the BeamSearchDecoder.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width:  Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
        to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
    """
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
    if (output_layer is not None
        and not isinstance(output_layer, layers_base.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._output_layer = output_layer

    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = array_ops.size(start_tokens)
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams,
        initial_state, self._cell.state_size)
    self._start_tokens = array_ops.tile(
        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)
    self._finished = array_ops.zeros(
        [self._batch_size, self._beam_width], dtype=dtypes.bool)
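
A rough inference-time sketch using the stock `tf.contrib.seq2seq.BeamSearchDecoder` (same constructor as above); the vocabulary size, token ids, and the stand-in encoder state are assumptions made for illustration.

import tensorflow as tf

vocab_size, embed_dim, beam_width, batch_size = 10000, 128, 5, 32
GO_ID, EOS_ID = 1, 2
embedding_matrix = tf.get_variable("embedding", [vocab_size, embed_dim])
decoder_cell = tf.nn.rnn_cell.LSTMCell(256)
encoder_final_state = decoder_cell.zero_state(batch_size, tf.float32)  # stand-in for the real encoder state

# Tile the encoder state across beams with tile_batch (NOT tf.tile).
tiled_state = tf.contrib.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embedding_matrix,
    start_tokens=tf.fill([batch_size], GO_ID),
    end_token=EOS_ID,
    initial_state=tiled_state,
    beam_width=beam_width,
    output_layer=tf.layers.Dense(vocab_size, use_bias=False),
    length_penalty_weight=0.6)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=50)
predicted_ids = outputs.predicted_ids  # [batch, time, beam_width]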
Example #18
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None):
        """Construct the `AttentionWrapper`.
		**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
		`AttentionWrapper`, then you must ensure that:
		- The encoder output has been tiled to `beam_width` via
			@{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`).
		- The `batch_size` argument passed to the `zero_state` method of this
			wrapper is equal to `true_batch_size * beam_width`.
		- The initial state created with `zero_state` above contains a
			`cell_state` value containing properly tiled final state from the
			encoder.
		An example:
		```
		tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
				encoder_outputs, multiplier=beam_width)
		tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
				encoder_final_state, multiplier=beam_width)
		tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
				sequence_length, multiplier=beam_width)
		attention_mechanism = MyFavoriteAttentionMechanism(
				num_units=attention_depth,
				memory=tiled_encoder_outputs,
				memory_sequence_length=tiled_sequence_length)
		attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
		decoder_initial_state = attention_cell.zero_state(
				dtype, batch_size=true_batch_size * beam_width)
		decoder_initial_state = decoder_initial_state.clone(
				cell_state=tiled_encoder_final_state)
		```
		Args:
			cell: An instance of `RNNCell`.
			attention_mechanism: A list of `AttentionMechanism` instances or a single
				instance.
			attention_layer_size: A list of Python integers or a single Python
				integer, the depth of the attention (output) layer(s). If None
				(default), use the context as attention at each time step. Otherwise,
				feed the context and cell output into the attention layer to generate
				attention at each time step. If attention_mechanism is a list,
				attention_layer_size must be a list of the same length.
			alignment_history: Python boolean, whether to store alignment history
				from all time steps in the final output state (currently stored as a
				time major `TensorArray` on which you must call `stack()`).
			cell_input_fn: (optional) A `callable`.  The default is:
				`lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
			output_attention: Python bool.  If `True` (default), the output at each
				time step is the attention value.  This is the behavior of Luong-style
				attention mechanisms.  If `False`, the output at each time step is
				the output of `cell`.  This is the behavior of Bahdanau-style
				attention mechanisms.  In both cases, the `attention` tensor is
				propagated to the next time step via the state and is used there.
				This flag only controls whether the attention mechanism is propagated
				up to the next cell in an RNN stack or to the top RNN output.
			initial_cell_state: The initial state value to use for the cell when
				the user calls `zero_state()`.  Note that if this value is provided
				now, and the user uses a `batch_size` argument of `zero_state` which
				does not match the batch size of `initial_cell_state`, proper
				behavior is not guaranteed.
			name: Name to use when creating ops.
		Raises:
			TypeError: `attention_layer_size` is not None and (`attention_mechanism`
				is a list but `attention_layer_size` is not; or vice versa).
			ValueError: if `attention_layer_size` is not None, `attention_mechanism`
				is a list, and its length does not match that of `attention_layer_size`.
		"""
        super(AttentionWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, saw type: %s" %
                            type(cell).__name__)
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        "attention_mechanism must contain only instances of "
                        "AttentionMechanism, saw type: %s" %
                        type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError(
                    "attention_mechanism must be an AttentionMechanism or list of "
                    "multiple AttentionMechanism instances, saw type: %s" %
                    type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism, )

        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(attention_layer_size if isinstance(
                attention_layer_size, (list,
                                       tuple)) else (attention_layer_size, ))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError(
                    "If provided, attention_layer_size must contain exactly one "
                    "integer per attention_mechanism, saw: %d vs %d" %
                    (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(
                layers_core.Dense(attention_layer_size,
                                  name="attention_layer",
                                  use_bias=False)
                for attention_layer_size in attention_layer_sizes)
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            self._attention_layers = None
            self._attention_layer_size = sum(
                attention_mechanism.values.get_shape()[-1].value
                for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with ops.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttentionWrapper %s: " % self._base_name
                    + "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.  Are you using "
                    "the BeamSearchDecoder?  You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: array_ops.identity(
                            s, name="check_initial_cell_state"),
                        initial_cell_state)
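A minimal construction sketch for the wrapper above (single attention mechanism, no beam search), again assuming TF 1.x contrib APIs; the zero-filled encoder outputs and stand-in encoder state are placeholders.

import tensorflow as tf

num_units, batch_size, max_time, depth = 128, 4, 10, 128
encoder_outputs = tf.zeros([batch_size, max_time, depth])  # stand-in encoder output
sequence_length = tf.fill([batch_size], max_time)

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=num_units,
    memory=encoder_outputs,
    memory_sequence_length=sequence_length)

cell = tf.nn.rnn_cell.LSTMCell(num_units)
attention_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell,
    attention_mechanism,
    attention_layer_size=num_units,  # context + cell output go through a Dense layer
    alignment_history=True)          # keep per-step alignments in the state

# Stand-in encoder final state; its structure must match cell.state_size.
encoder_final_state = cell.zero_state(batch_size, tf.float32)
decoder_initial_state = attention_cell.zero_state(
    batch_size, tf.float32).clone(cell_state=encoder_final_state)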
Example #19
    def __init__(self,
                 cell,
                 embedding,
                 start_tokens,
                 end_token,
                 initial_state,
                 beam_width,
                 output_layer=None,
                 length_penalty_weight=0.0):
        """Initialize BeamSearchDecoder.
        Args:
          cell: An `RNNCell` instance.
          embedding: A callable that takes a vector tensor of `ids` (argmax ids),
            or the `params` argument for `embedding_lookup`.
          start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
          end_token: `int32` scalar, the token that marks end of decoding.
          initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
          beam_width: Python integer, the number of beams.
          output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
            `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
            to storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
        Raises:
          TypeError: if `cell` is not an instance of `RNNCell`,
            or `output_layer` is not an instance of `tf.layers.Layer`.
          ValueError: If `start_tokens` is not a vector or
            `end_token` is not a scalar.
        """
        if not rnn_cell_impl._like_rnncell(cell):
            raise TypeError(
                "cell must be an RNNCell, received: %s" % type(cell))
        if (output_layer is not None
                and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError(
                "output_layer must be a Layer, received: %s" % type(output_layer))
        self._cell = cell
        self._output_layer = output_layer

        if callable(embedding):
            self._embedding_fn = embedding
        else:
            self._embedding_fn = (
                lambda ids: embedding_ops.embedding_lookup(embedding, ids))

        self._start_tokens = ops.convert_to_tensor(
            start_tokens, dtype=dtypes.int32, name="start_tokens")
        if self._start_tokens.get_shape().ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._end_token = ops.convert_to_tensor(
            end_token, dtype=dtypes.int32, name="end_token")
        if self._end_token.get_shape().ndims != 0:
            raise ValueError("end_token must be a scalar")

        self._batch_size = tf.size(start_tokens)
        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams,
            initial_state, self._cell.state_size)
        self._start_tokens = tf.tile(
            tf.expand_dims(self._start_tokens, 1), [1, self._beam_width])
        self._start_inputs = self._embedding_fn(self._start_tokens)
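A small shape sketch of the start-token handling at the end of this variant: each start token is tiled across the beam dimension before the embedding lookup (toy sizes are illustrative).

import tensorflow as tf

batch_size, beam_width, vocab_size, embed_dim = 4, 3, 100, 8
embedding = tf.get_variable("emb", [vocab_size, embed_dim])
start_tokens = tf.fill([batch_size], 1)                            # [batch]
tiled = tf.tile(tf.expand_dims(start_tokens, 1), [1, beam_width])  # [batch, beam]
start_inputs = tf.nn.embedding_lookup(embedding, tiled)            # [batch, beam, embed]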
Example #20
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None):
        super(GatedAttentionWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, saw type: %s" %
                            type(cell).__name__)
        if not isinstance(attention_mechanism, AttentionMechanism):
            raise TypeError(
                "attention_mechanism must be a AttentionMechanism, saw type: %s"
                % type(attention_mechanism).__name__)
        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            self._attention_layer = layers_core.Dense(attention_layer_size,
                                                      name="attention_layer",
                                                      use_bias=False)
            self._attention_layer_size = attention_layer_size
        else:
            self._attention_layer = None
            self._attention_layer_size = attention_mechanism.values.get_shape(
            )[-1].value

        self._cell = cell
        self._attention_mechanism = attention_mechanism
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with ops.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttentionWrapper %s: " % self._base_name
                    + "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.  Are you using "
                    "the BeamSearchDecoder?  You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with ops.control_dependencies([
                        check_ops.assert_equal(
                            state_batch_size,
                            self._attention_mechanism.batch_size,
                            message=error_message)
                ]):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: array_ops.identity(
                            s, name="check_initial_cell_state"),
                        initial_cell_state)
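The batch-size check above hinges on `nest.map_structure` touching every tensor in a possibly nested state; a tiny sketch of that pattern, assuming the TF 1.x internal `tensorflow.python.util.nest` module.

import tensorflow as tf
from tensorflow.python.util import nest

cell = tf.nn.rnn_cell.LSTMCell(32)
state = cell.zero_state(4, tf.float32)  # LSTMStateTuple(c, h)
# Apply tf.identity to every tensor in the (possibly nested) structure; under a
# control dependency on an assertion, this is how the wrappers above attach the
# batch-size check to each state tensor.
checked = nest.map_structure(
    lambda s: tf.identity(s, name="check_initial_cell_state"), state)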
Example #21
    def __init__(self,
                 cells,
                 initial_states,
                 is_training,
                 seq_len,
                 use_conv_lstm,
                 first_image,
                 input_sequence=None,
                 input_latent_sample_sequence=None,
                 initial_inputs=None,
                 output_layer=None,
                 autoregress=False,
                 reencode=False,
                 encoder_cnn=None,
                 decoder_cnn=None,
                 encoder_data_format="NCHW",
                 fixed_prior=False,
                 data_format="NCHW",
                 image_activation=None,
                 init_inference=False):
        """Initialize VariationalDecoder.
    Args:
      cells: Multiple `RNNCell` instances.
      initial_states: A (possibly nested tuple of...) tensors and TensorArrays.
        The initial states of the RNNCells.
      is_training: True if in training mode.
      seq_len: Desired output sequence length.
      use_conv_lstm: If True convolutional LSTM is used in cells argument.
      input_sequence: (Optional) Sequence of input embeddings that replaces
        autoregressive feeding of previous output as next input.
      input_latent_sample_sequence: (Optional) Sequence of input z latents that are fed to
        the decoder instead of sampling new ones.
      initial_inputs: (Optional) Initial input frame, if None first real
        input is used instead (and no prediction is made for first frame)
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
        to storing the result.
      autoregress: If true, the prior and decoder LSTMs observe the autoregressively
        generated latents. The inference LSTM still observes the ground truth, if any
      encoder_cnn: (optional) CNN used for autoregressive reencoding.
      decoder_cnn: (optional) CNN used for autoregressive reencoding.
      data_format: (optional) Desired data format of cnn_decoder,
        used to determine whether dimensions need to be re-ordered prior to
        reencoding in autoregression case. RNN always takes NHWC.
      init_inference: If True, inference network is initialized by running on the
        initial_input.
    Raises:
      TypeError: if any cell in `cells` or `output_layer` has an incorrect type.
    """
        # super(VariationalDecoder, self).__init__(name=name)
        for _, cell in cells.items():
            if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
                raise TypeError("cell must be an RNNCell, received: %s" %
                                type(cell))
        if (output_layer is not None
                and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError("output_layer must be a Layer, received: %s" %
                            type(output_layer))
        if (initial_inputs is None):
            raise ValueError(
                "Need to give either initial input or input sequence for Variational Decoder!"
            )

        self._cells = cells
        self._initial_states = initial_states
        self._is_training = is_training
        self._seq_len = seq_len
        self._use_conv_lstm = use_conv_lstm
        self._output_layer = output_layer
        self._fixed_prior = fixed_prior
        self._input_sequence = input_sequence
        self._batch_size = initial_inputs.get_shape().as_list()[0]
        self._use_cdna_model = tf.flags.FLAGS.use_cdna_model
        self._autoregress = autoregress
        self._cnn_encoder = encoder_cnn
        self._cnn_decoder = decoder_cnn
        self._encoder_data_format = encoder_data_format
        self._reencode = reencode
        self._input_latent_sample_sequence = input_latent_sample_sequence
        self._prev_inputs = initial_inputs
        self._data_format = data_format
        self._first_image = first_image

        if input_sequence is not None:
            self._input_seq_len = input_sequence.get_shape().as_list()[0]
            if not self._use_conv_lstm:
                self._input_sequence = tf.reshape(
                    self._input_sequence,
                    [self._input_seq_len, self._batch_size, -1])
            if self._input_seq_len != self._seq_len:
                tf.logging.warning(
                    'VariationalDecoder input sequence length and desired output '
                    'length do not match. Is this desired? They are %d and %d',
                    self._input_seq_len, self._seq_len)
        else:
            self._input_seq_len = 0
        if input_latent_sample_sequence is not None:
            if input_latent_sample_sequence.get_shape().as_list(
            )[0] != self._seq_len:
                raise ValueError(
                    "Input Latent sequence must have the same length as the desired"
                    "output sequence")

        def get_size(output_size):
            # This is needed as there is an inconsistency between linear layers and LSTM
            # linear layers return an int, whereas LSTM returns TensorShape
            if isinstance(output_size, int):
                return output_size
            else:
                return output_size.as_list()[-1]

        self._inf_output_size = get_size(self._cells['inference'].output_size)
        self._prior_output_size = get_size(self._cells['prior'].output_size)
        self._sample_dim = int(self._inf_output_size * 0.5)
        # sanity check sample distribution dimensions
        if (self._inf_output_size % 2 != 0
                or self._inf_output_size != self._prior_output_size):
            raise ValueError(
                "Dimensions of Inference and Prior distribution are not valid, "
                "they are: %d, %d" %
                (self._inf_output_size, self._prior_output_size))
        if init_inference:
            # run inference network on initial_inputs to initialize
            _, self._initial_states['inference'] = \
                  self._cells['inference'](self._maybe_encode_inputs(initial_inputs),
                                           self._initial_states['inference'])

        if reencode and not autoregress:
            # (oleh) note that autoregression will also be executed when out of input frames
            # in which case this combination will be meaningful
            raise ValueError(
                "Reencoding is only supported with autoregression.")

        if image_activation is None:
            self._image_activation = lambda x: x
        else:
            self._image_activation = image_activation
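The `get_size` helper above papers over cells reporting `output_size` either as an `int` or as a `TensorShape`; a rough, hedged illustration of that difference using stock cells (the author's custom cells may behave differently).

import tensorflow as tf

def get_size(output_size):
    # Same idea as the helper above: normalize int vs. TensorShape.
    return output_size if isinstance(output_size, int) else output_size.as_list()[-1]

lstm = tf.nn.rnn_cell.BasicLSTMCell(64)        # output_size is an int
conv_lstm = tf.contrib.rnn.ConvLSTMCell(
    conv_ndims=2, input_shape=[8, 8, 3],
    output_channels=16, kernel_shape=[3, 3])   # output_size is a TensorShape
print(get_size(lstm.output_size), get_size(conv_lstm.output_size))  # 64 16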
Example #22
    def __init__(self, cell, attention_mechanism, keep_prob, training, attention_layer_size=None, alignment_history=False, cell_input_fn=None,
                output_attention=True, initial_cell_state=None, name=None):
        super(MyAttentionWrapper, self).__init__(cell, attention_mechanism, attention_layer_size, alignment_history, cell_input_fn,
                output_attention, initial_cell_state, name)
        
        self.keep_prob = keep_prob
        self.training = training
        
        super(tf.contrib.seq2seq.AttentionWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, saw type: %s" % type(cell).__name__)
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError("attention_mechanism must contain only instances of "
                                    "AttentionMechanism, saw type: %s" % type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, tf.contrib.seq2seq.AttentionMechanism):
                raise TypeError("attention_mechanism must be an AttentionMechanism or list of "
                                "multiple AttentionMechanism instances, saw type: %s" % type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism,)

        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, attention: tf.concat([inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError("cell_input_fn must be callable, saw type: %s" % type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(attention_layer_size if isinstance(attention_layer_size, (list, tuple))
                                              else (attention_layer_size,))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError("If provided, attention_layer_size must contain exactly one "
                                "integer per attention_mechanism, saw: %d vs %d" % (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(layers_core.Dense(attention_layer_size, name="attention_layer", use_bias=True,
                    activation=tf.tanh) for attention_layer_size in attention_layer_sizes)
            self._attention_dropout_layers = tuple(layers_core.Dropout(rate=1-self.keep_prob, name="attention_dropout_layer")
                                                  for attention_layer_size in attention_layer_sizes)
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            self._attention_layers = None
            self._attention_dropout_layers = None
            self._attention_layer_size = sum(attention_mechanism.values.get_shape()[-1].value
                                                      for attention_mechanism in attention_mechanisms)
            
        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with tf.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value or tf.shape(final_state_tensor)[0])
                error_message = ('custom error msg:0')
                with tf.control_dependencies(
                    self._batch_size_checks(state_batch_size, error_message)):
                    self._initial_cell_state = nest.map_structure(lambda s: tf.identity(s, name="check_initial_cell_state"),
                                                              initial_cell_state)
        self.attention_layer = layers_core.Dense(512, activation=tf.tanh)
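The substantive additions in this variant are a `tanh` Dense attention layer and dropout on its output; a small sketch of that pairing with the public `tf.layers` API (the layer size, rate, and stand-in context are arbitrary).

import tensorflow as tf

attention_layer = tf.layers.Dense(512, activation=tf.tanh, use_bias=True,
                                  name="attention_layer")
attention_dropout = tf.layers.Dropout(rate=0.2, name="attention_dropout_layer")

context_vector = tf.zeros([4, 256])  # stand-in attention context
attended = attention_dropout(attention_layer(context_vector), training=True)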
Example #23
    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build a RNN cell with attention mechanism that can be used by decoder."""
        attention_option = hparams.attention
        attention_architecture = hparams.attention_architecture

        # if attention_architecture != "joint":
        #   raise ValueError(
        #       "Unknown attention architecture %s" % attention_architecture)

        num_units = hparams.num_units
        num_layers = hparams.num_layers
        num_residual_layers = hparams.num_residual_layers
        num_gpus = hparams.num_gpus
        beam_width = hparams.beam_width

        dtype = tf.float32

        # Ensure memory is batch-major
        if self.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            memory = encoder_outputs

        if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0:
            memory = tf.contrib.seq2seq.tile_batch(memory,
                                                   multiplier=beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(
                source_sequence_length, multiplier=beam_width)
            encoder_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=beam_width)
            batch_size = self.batch_size * beam_width
        else:
            batch_size = self.batch_size

        attention_mechanism = self.attention_mechanism_fn(
            attention_option, num_units, memory, source_sequence_length,
            self.mode)

        cell = model_helper.create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=num_units,
            num_layers=num_layers,
            num_residual_layers=num_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn)

        # Only generate alignment in greedy INFER mode.
        alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER
                             and beam_width == 0)
        cell = s2s.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=num_units,
            alignment_history=alignment_history,
            name="joint_attention",
            attention_architecture=hparams.attention_architecture,
            output_layer=self.output_layer,
            training=(self.mode != tf.contrib.learn.ModeKeys.INFER))

        print("Tyoe of cell is", cell, rnn_cell_impl._like_rnncell(cell),
              hasattr(cell, "state_size"))

        cell = tf.contrib.rnn.DeviceWrapper(
            cell, model_helper.get_device_str(num_layers - 1, num_gpus))

        print("Tyoe of cell1 is", cell, rnn_cell_impl._like_rnncell(cell))

        if hparams.pass_hidden_state:
            decoder_initial_state = cell.zero_state(
                batch_size, dtype).clone(cell_state=encoder_state)
        else:
            decoder_initial_state = cell.zero_state(batch_size, dtype)

        return cell, decoder_initial_state
Example #24
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the time axis)
    in addition to the outputs at each timestep and the final state

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )
    """
    if not _like_rnncell(cell):
        raise TypeError("cell must be an instance of RNNCell")
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if context.in_graph_mode():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                              array_ops.shape(emit) for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [s.shape if s.shape.is_fully_defined() else
                           array_ops.shape(s) for s in flat_state]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]

        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition, body, loop_vars=[
                time, elements_finished, next_input, state_ta,
                emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        flat_states = nest.flatten(state_ta)
        flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
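To show how this adapted `raw_rnn` is driven, here is a minimal `loop_fn` sketch; it assumes the function above is in scope together with its TensorFlow-internal imports, and the toy tensors and sizes are placeholders.

import tensorflow as tf

batch_size, max_time, input_depth, num_units = 4, 10, 8, 16
inputs = tf.zeros([max_time, batch_size, input_depth])  # time-major stand-in
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time).unstack(inputs)
sequence_length = tf.fill([batch_size], max_time)
cell = tf.nn.rnn_cell.LSTMCell(num_units)

def loop_fn(time, cell_output, cell_state, loop_state):
    if cell_output is None:
        # First call: provide the initial state; emitting None lets raw_rnn size
        # the output TensorArrays from cell.output_size.
        next_cell_state = cell.zero_state(batch_size, tf.float32)
        emit_output = None
    else:
        next_cell_state = cell_state
        emit_output = cell_output
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth]),
        lambda: inputs_ta.read(time))
    return (elements_finished, next_input, next_cell_state, emit_output, None)

states, outputs, final_state = raw_rnn(cell, loop_fn)
# states and outputs come back batch-major: [batch_size, max_time, num_units]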