Example #1
    def _time_step(time, output_ta_t, state):
        input_t = tuple(ta.read(time) for ta in input_ta)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            input_.set_shape(shape[1:])

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)
        call_cell = lambda: cell([input_t, gate_vector], state)  # The call function!
        if sequence_length is not None:
            (output, new_state) = rnn._rnn_step(
                time=time,
                sequence_length=sequence_length,
                min_sequence_length=min_sequence_length,
                max_sequence_length=max_sequence_length,
                zero_output=zero_output,
                state=state,
                call_cell=call_cell,
                state_size=state_size,
                skip_conditionals=True)
        else:  # This gets executed for us (UKP)
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)

        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, output))

        return (time + 1, output_ta_t, new_state)
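All of the examples on this page route the per-step computation through TensorFlow's private helper `rnn._rnn_step`. The following is a minimal sketch, not the library implementation, of what that helper does when `skip_conditionals=True`, assuming the output and state are single tensors (the real helper also handles nested state tuples):

import tensorflow as tf

def rnn_step_sketch(time, sequence_length, zero_output, state, call_cell):
    # Run the cell unconditionally, then copy the old state through and emit
    # zeros for batch rows whose sequence has already finished.
    new_output, new_state = call_cell()
    finished = time >= sequence_length          # [batch_size] bool vector
    output = tf.where(finished, zero_output, new_output)
    state = tf.where(finished, state, new_state)
    return output, state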
Example #2
    def _time_step(time, state, output_ta_t):
        """Take a time step of the dynamic RNN decoder.
        """
        input_t = input_ta.read(time)
        input_t.set_shape([const_batch_size, const_depth])
        if loop_function is not None:
            output_t_minus_1 = output_ta_t.read(time - 1)
            output_t_minus_1.set_shape([const_batch_size, cell.output_size])
            input_t = loop_function(input_t, output_t_minus_1, time)

        with variable_scope.variable_scope(scope, reuse=True):
            _output, _state = cell(input_t, state)

        call_cell = lambda: cell(input_t, state)
        if sequence_length is not None:
            # copy_cond = (time>=sequence_length)
            # new_output = math_ops.select(copy_cond, zeros_ouput, _output)
            # new_state = math_ops.select(copy_cond, state, _state)
            (new_output, new_state) = rnn._rnn_step(
                time=time,
                sequence_length=sequence_length,
                min_sequence_length=math_ops.reduce_min(sequence_length),
                max_sequence_length=math_ops.reduce_max(sequence_length),
                zero_output=zeros_ouput,
                state=state,
                call_cell=call_cell,
                state_size=cell.state_size,
                skip_conditionals=True)
        else:
            new_output, new_state = _output, _state

        output_ta_t = output_ta_t.write(time, new_output)

        return (time + 1, new_state, output_ta_t)
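A hedged sketch of a `loop_function` compatible with the three-argument call above; the argmax-and-embed behaviour and the `embedding` variable are assumptions modeled on the `extract_argmax_and_embed` helpers that appear in later examples, not part of this snippet:

import tensorflow as tf

def make_loop_function(embedding):
    # embedding: a [vocab_size, const_depth] variable, so the embedded symbol
    # has the same depth as the ground-truth decoder input.
    def loop_function(input_t, output_t_minus_1, time):
        prev_symbol = tf.argmax(output_t_minus_1, axis=1)
        emb = tf.nn.embedding_lookup(embedding, prev_symbol)
        # There is no previous output at time 0, so keep the real input there.
        return tf.cond(tf.equal(time, 0), lambda: input_t, lambda: emb)
    return loop_function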
Example #3
    def _time_step(time, output_ta_t, states):
        """Take a time step of the dynamic RNN.

    Args:
      time: int32 scalar Tensor.
      output_ta_t: List of `TensorArray`s that represent the output.
      states: `TensorArray` holding the cell state written at each step.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, states with updated flow).
    """
        if in_graph_mode:
            input_t = tuple(ta.read(time) for ta in input_ta)
            # Restore some shape information
            for input_, shape in zip(input_t, inputs_got_shape):
                input_.set_shape(shape[1:])
        else:
            input_t = tuple(ta[time.numpy()] for ta in input_ta)

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)
        row_time = time % width
        states_1 = control_flow_ops.cond(math_ops.less(time, width),
                                         lambda: states.read(time_steps),
                                         lambda: states.read(time - width))
        states_2 = control_flow_ops.cond(math_ops.equal(row_time, 0),
                                         lambda: states.read(time_steps),
                                         lambda: states.read(time - 1))

        call_cell = lambda: cell(input_t, tuple((states_1, states_2)))
        if sequence_length is not None:
            (output,
             new_state) = _rnn_step(time=row_time,
                                    sequence_length=sequence_length,
                                    min_sequence_length=min_sequence_length,
                                    max_sequence_length=max_sequence_length,
                                    zero_output=zero_output,
                                    state=tuple((states_1, states_2)),
                                    call_cell=call_cell,
                                    state_size=state_size,
                                    skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)
        if in_graph_mode:
            output_ta_t = tuple(
                ta.write(time, out) for ta, out in zip(output_ta_t, output))
            states = states.write(time, new_state)
        else:
            for ta, out in zip(output_ta_t, output):
                ta[time.numpy()] = out

        return (time + 1, output_ta_t, states)
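This example drives a two-dimensional grid of cells with a single 1-D loop index. A small standalone sketch of the index arithmetic behind the two `cond` reads, with illustrative values for `width` and `time_steps` (the slot `time_steps` of the `states` TensorArray holds the initial state):

width, time_steps = 4, 12  # illustrative values only

def neighbor_indices(time):
    row_time = time % width
    above = time_steps if time < width else time - width  # states_1
    left = time_steps if row_time == 0 else time - 1      # states_2
    return above, left

# The first row and the first column fall back to the initial-state slot.
print([neighbor_indices(t) for t in range(8)])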
Example #4
    def _time_step(time, output_ta_t, state):
        """Take a time step of the dynamic RNN.

    Args:
      time: int32 scalar Tensor.
      output_ta_t: List of `TensorArray`s that represent the output.
      state: nested tuple of vector tensors that represent the state.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """

        input_t = tuple(ta.read(time) for ta in input_ta)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            input_.set_shape(shape[1:])

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)
        call_cell = lambda: cell(input_t, state)

        def f1():
            return zero_output

        def f2():
            return tuple(
                ta.read(tf.subtract(time, 1))
                for ta in output_ta_t)  #output_ta_t.read(tf.subtract(time, 1))

        cur_zero_output = tf.cond(tf.less(time, 1), f1, f2)

        if sequence_length is not None:
            (output, new_state) = rnn._rnn_step(
                time=time,
                sequence_length=sequence_length,
                min_sequence_length=min_sequence_length,
                max_sequence_length=max_sequence_length,
                zero_output=cur_zero_output,  # TODO
                state=state,
                call_cell=call_cell,
                state_size=state_size,
                skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)

        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, output))

        return (time + 1, output_ta_t, new_state)
Example #5
        def _time_step(time, input_, state, output, proj_outputs, decoder_outputs, samples, states, weights,
                       prev_weights):
            context_vector, new_weights = attention_(state, prev_weights=prev_weights)
            weights = weights.write(time, new_weights)

            # FIXME use `output` or `state` here?
            output_ = linear_unsafe([state, input_, context_vector], decoder.cell_size, False, scope='maxout')
            output_ = tf.reduce_max(tf.reshape(output_, tf.stack([batch_size, decoder.cell_size // 2, 2])), axis=2)
            output_ = linear_unsafe(output_, decoder.embedding_size, False, scope='softmax0')
            decoder_outputs = decoder_outputs.write(time, output_)
            output_ = linear_unsafe(output_, output_size, True, scope='softmax1')
            proj_outputs = proj_outputs.write(time, output_)

            argmax = lambda: tf.argmax(output_, 1)
            softmax = lambda: tf.squeeze(tf.multinomial(tf.log(tf.nn.softmax(output_)), num_samples=1),
                                         axis=1)
            target = lambda: inputs.read(time + 1)

            sample = tf.case([
                (tf.logical_and(time < time_steps - 1, tf.random_uniform([]) >= feed_previous), target),
                (tf.logical_not(feed_argmax), softmax)],
                default=argmax)   # default case is useful for beam-search

            sample.set_shape([None])
            sample = tf.stop_gradient(sample)

            samples = samples.write(time, sample)
            input_ = embed(sample)

            x = tf.concat([input_, context_vector], 1)
            call_cell = lambda: unsafe_decorator(cell)(x, state)

            if sequence_length is not None:
                new_output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                new_output, new_state = call_cell()

            states = states.write(time, new_state)

            return (time + 1, input_, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights,
                    new_weights)
Example #6
    def _time_step(time, output_ta_t, state):
        """Take a time step of the dynamic RNN.
        Args:
          time: int32 scalar Tensor.
          output_ta_t: List of `TensorArray`s that represent the output.
          state: nested tuple of vector tensors that represent the state.
        Returns:
          The tuple (time + 1, output_ta_t with updated flow, new_state).
        """

        if in_graph_mode:
            input_t = tuple(ta.read(time) for ta in input_ta)
            # Restore some shape information
            for input_, shape in zip(input_t, inputs_got_shape):
                input_.set_shape(shape[1:])
        else:
            input_t = tuple(ta[time.numpy()] for ta in input_ta)

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)
        # Here we modify the call so that 'time' is also passed to the cell.
        call_cell = lambda: cell(input_t, state, time)

        if sequence_length is not None:
            (output, new_state) = rnn._rnn_step(
                time=time,
                sequence_length=sequence_length,
                min_sequence_length=min_sequence_length,
                max_sequence_length=max_sequence_length,
                zero_output=zero_output,
                state=state,
                call_cell=call_cell,
                state_size=state_size,
                skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)

        if in_graph_mode:
            output_ta_t = tuple(
                ta.write(time, out) for ta, out in zip(output_ta_t, output))
        else:
            for ta, out in zip(output_ta_t, output):
                ta[time.numpy()] = out

        return (time + 1, output_ta_t, new_state)
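Example #6 calls the cell as `cell(input_t, state, time)`, so the cell must accept the extra `time` argument. A hypothetical sketch of such a wrapper (the original cell class is not shown; appending the time index as an input feature is purely illustrative):

import tensorflow as tf

class TimeAwareCellSketch(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(TimeAwareCellSketch, self).__init__()
        self._inner = inner_cell

    @property
    def state_size(self):
        return self._inner.state_size

    @property
    def output_size(self):
        return self._inner.output_size

    def __call__(self, inputs, state, time, scope=None):
        # Broadcast the scalar time index to one extra feature per batch row.
        batch_size = tf.shape(inputs)[0]
        t = tf.fill(tf.stack([batch_size, 1]), tf.cast(time, inputs.dtype))
        return self._inner(tf.concat([inputs, t], axis=1), state)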
Example #7
  def _time_step(time, output_ta_t, state):
    """Take a time step of the dynamic RNN.

    Args:
      time: int32 scalar Tensor.
      output_ta_t: List of `TensorArray`s that represent the output.
      state: nested tuple of vector tensors that represent the state.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """

    input_t = tuple(ta.read(time) for ta in input_ta)
    # Restore some shape information
    for input_, shape in zip(input_t, inputs_got_shape):
      input_.set_shape(shape[1:])

    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
    call_cell = lambda: cell(input_t, state)

    def f1():
      return zero_output

    def f2():
      # output_ta_t.read(tf.subtract(time, 1))
      return tuple(ta.read(tf.subtract(time, 1)) for ta in output_ta_t)

    cur_zero_output = tf.cond(tf.less(time, 1), f1, f2)

    if sequence_length is not None:
      (output, new_state) = rnn._rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=cur_zero_output, # TODO
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()

    # Pack state if using state tuples
    output = nest.flatten(output)

    output_ta_t = tuple(ta.write(time, out) for ta, out in zip(output_ta_t, output))

    return (time + 1, output_ta_t, new_state)
Example #8
    def _time_step(time, output_ta_t, state_ta_t, state):
        """
        Take a time step of the dynamic RNN.
        """
        input_t = tuple(ta.read(time) for ta in input_ta)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            input_.set_shape(shape[1:])

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)
        call_cell = lambda: cell(input_t, state)

        if sequence_length is not None:
            (output,
             new_state) = _rnn_step(time=time,
                                    sequence_length=sequence_length,
                                    min_sequence_length=min_sequence_length,
                                    max_sequence_length=max_sequence_length,
                                    zero_output=zero_output,
                                    state=state,
                                    call_cell=call_cell,
                                    state_size=state_size,
                                    skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)
        new_state_ = nest.flatten(new_state)

        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, output))
        state_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(state_ta_t, new_state_))

        return (time + 1, output_ta_t, state_ta_t, new_state)
Example #9
        def _time_step(time, state, _, attn_weights, output_ta_t, attn_weights_ta_t):
            input_t = input_ta.read(time)
            # restore some shape information
            r = tf.random_uniform([])
            input_t = tf.cond(tf.logical_and(time > 0, r < feed_previous),
                              lambda: tf.stop_gradient(extract_argmax_and_embed(output_ta_t.read(time - 1))),
                              lambda: input_t)
            input_t.set_shape(decoder_inputs.get_shape()[1:])
            # the code from TensorFlow used a concatenation of input_t and attns as input here
            # TODO: evaluate the impact of this
            call_cell = lambda: unsafe_decorator(cell)(input_t, state)

            if sequence_length is not None:
                output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                output, new_state = call_cell()

            attn_weights_ta_t = attn_weights_ta_t.write(time, attn_weights)
            # using decoder state instead of decoder output in the attention model seems
            # to give much better results
            new_attns, new_attn_weights = attention_(new_state, prev_weights=attn_weights)

            with tf.variable_scope('attention_output_projection'):  # this can take a lot of memory
                output = linear_unsafe([output, new_attns], output_size, True)

            output_ta_t = output_ta_t.write(time, output)
            return time + 1, new_state, new_attns, new_attn_weights, output_ta_t, attn_weights_ta_t
Example #10
    def _time_step(time, output_ta_t, state_ta_t, state):
        """
        Take a time step of the dynamic RNN.
        """
        input_t = tuple(ta.read(time) for ta in input_ta)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            input_.set_shape(shape[1:])

        input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
        call_cell = lambda: cell(input_t, state)

        if sequence_length is not None:
            (output, new_state) = _rnn_step(
                time=time,
                sequence_length=sequence_length,
                min_sequence_length=min_sequence_length,
                max_sequence_length=max_sequence_length,
                zero_output=zero_output,
                state=state,
                call_cell=call_cell,
                state_size=state_size,
                skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)
        new_state_ = nest.flatten(new_state)

        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, output))
        state_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(state_ta_t, new_state_))

        return (time + 1, output_ta_t, state_ta_t, new_state)
Example #11
def custom_rnn(cell, inputs, initial_state=None, dtype=None,
        sequence_length=None, scope=None):
  """Creates a recurrent neural network specified by RNNCell "cell".

  The simplest form of RNN network generated is:
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)

  However, a few other options are available:

  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.

  The dynamic calculation performed is, at time t for batch row b,
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))

  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a tensor of shape
      [batch_size, cell.input_size].
    initial_state: (optional) An initial state for the RNN.  This must be
      a tensor of appropriate type and shape [batch_size x cell.state_size].
    dtype: (optional) The data type for the initial state.  Required if
      initial_state is not provided.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) size [batch_size].  Values in [0, T).
    scope: VariableScope for the created subgraph; defaults to "RNN".

  Returns:
    A pair (outputs, state) where:
      outputs is a length T list of outputs (one for each input)
      state is the final state

  Raises:
    TypeError: If "cell" is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
  """

  if not isinstance(cell, tf.nn.rnn_cell.RNNCell):
    raise TypeError("cell must be an instance of RNNCell")
  if not isinstance(inputs, list):
    raise TypeError("inputs must be a list")
  if not inputs:
    raise ValueError("inputs must not be empty")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with tf.variable_scope(scope or "RNN") as varscope:
    fixed_batch_size = inputs[0].get_shape().with_rank_at_least(1)[0]
    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(inputs[0])[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, dtype must be.")
      state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:
      sequence_length = tf.to_int32(sequence_length)

    if sequence_length is not None:  # Prepare variables
      zero_output = tf.zeros(
          tf.pack([batch_size, cell.output_size]), inputs[0].dtype)
      zero_output.set_shape(
          tf.TensorShape([fixed_batch_size.value, cell.output_size]))
      min_sequence_length = tf.reduce_min(sequence_length)
      max_sequence_length = tf.reduce_max(sequence_length)

    c1h1 = []

    for time, input_ in enumerate(inputs):
      if time > 0: tf.get_variable_scope().reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time, sequence_length, min_sequence_length, max_sequence_length,
            zero_output, state, call_cell)
      else:
        (output, state) = call_cell()
      if time==0:
        c1h1 = state
      outputs.append(output)

    return (outputs, state, c1h1)
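A hypothetical usage sketch for `custom_rnn` above; the placeholder shapes, the GRU cell, and names such as `num_steps` and `hidden_dim` are illustrative assumptions, not taken from the original code:

import tensorflow as tf

num_steps, input_dim, hidden_dim = 10, 32, 64
inputs = [tf.placeholder(tf.float32, [None, input_dim], name="x_%d" % t)
          for t in range(num_steps)]
lengths = tf.placeholder(tf.int32, [None], name="lengths")
cell = tf.nn.rnn_cell.GRUCell(hidden_dim)

# outputs: per-step outputs, state: final state, c1h1: state after step 0.
outputs, state, c1h1 = custom_rnn(cell, inputs, dtype=tf.float32,
                                  sequence_length=lengths)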
Example #12
def rnn_with_output_feedback(cell,
                             inputs,
                             targets1,
                             targets1_num_symbols,
                             target1_emb_size,
                             target1_output_projection,
                             targets2,
                             targets2_num_symbols,
                             target2_emb_size,
                             target2_output_projection,
                             word_emb_size,
                             DNN_at_output,
                             zero_intent_thres=0,
                             sequence_length=None,
                             dtype=None,
                             train_with_true_label=True,
                             use_predicted_output=False):
    '''
  zero_intent_thres:  int, the intent contribution to the context remains zero before this
                      threshold, and increases linearly to 1 after that.
  '''
    if not isinstance(cell, tf.contrib.rnn.RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    if not isinstance(inputs, list):
        raise TypeError("inputs must be a list")
    if not isinstance(targets1, list):
        raise TypeError("targets1 must be a list")
    if not isinstance(targets2, list):
        raise TypeError("targets2 must be a list")
    if not inputs:
        raise ValueError("inputs must not be empty")
    if not dtype:
        raise ValueError(
            "dtype must be provided, which is to used in defining intial RNN state"
        )

    encoder_outputs = []
    intent_embedding = variable_scope.get_variable(
        "intent_embedding", [targets1_num_symbols, target1_emb_size])
    tag_embedding = variable_scope.get_variable(
        "tag_embedding", [targets2_num_symbols, target2_emb_size])
    # use predicted label if use_predicted_output during inference, use true label during training
    # To choose to always use predicted label, disable the if condition
    intent_loop_function = _extract_argmax_and_embed(
        intent_embedding,
        DNN_at_output,
        target1_output_projection,
        forward_only=use_predicted_output)  #if use_predicted_output else None
    tagging_loop_function = _extract_argmax_and_embed(
        tag_embedding,
        DNN_at_output,
        target2_output_projection,
        forward_only=use_predicted_output)
    intent_targets = [
        array_ops.reshape(math_ops.to_int64(x), [-1]) for x in targets1
    ]
    intent_target_embeddings = list()
    intent_target_embeddings = [
        embedding_ops.embedding_lookup(intent_embedding, target)
        for target in intent_targets
    ]
    tag_targets = [
        array_ops.reshape(math_ops.to_int64(x), [-1]) for x in targets2
    ]
    tag_target_embeddings = list()
    tag_target_embeddings = [
        embedding_ops.embedding_lookup(tag_embedding, target)
        for target in tag_targets
    ]

    if inputs[0].get_shape().ndims != 1:
        (fixed_batch_size, input_size) = inputs[0].get_shape().with_rank(2)
        if input_size.value is None:
            raise ValueError(
                "Input size (second dimension of inputs[0]) must be accessible via "
                "shape inference, but saw value None.")
    else:
        fixed_batch_size = inputs[0].get_shape().with_rank_at_least(1)[0]

    if fixed_batch_size.value:
        batch_size = fixed_batch_size.value
    else:
        batch_size = array_ops.shape(inputs[0])[0]

    state = cell.zero_state(batch_size, dtype)
    zero_output = array_ops.zeros(
        array_ops.stack([batch_size, cell.output_size]), inputs[0].dtype)
    zero_output.set_shape(
        tensor_shape.TensorShape([fixed_batch_size.value, cell.output_size]))

    if sequence_length is not None:  # Prepare variables
        sequence_length = math_ops.to_int32(sequence_length)
        min_sequence_length = math_ops.reduce_min(sequence_length)
        max_sequence_length = math_ops.reduce_max(sequence_length)

#  prev_cell_output = zero_output
    zero_intent_embedding = array_ops.zeros(
        array_ops.stack([batch_size, target1_emb_size]), inputs[0].dtype)
    zero_intent_embedding.set_shape(
        tensor_shape.TensorShape([fixed_batch_size.value, target1_emb_size]))
    zero_tag_embedding = array_ops.zeros(
        array_ops.stack([batch_size, target2_emb_size]), inputs[0].dtype)
    zero_tag_embedding.set_shape(
        tensor_shape.TensorShape([fixed_batch_size.value, target2_emb_size]))

    encoder_outputs = list()
    intent_logits = list()
    tagging_logits = list()
    sampled_intent_embeddings = list()
    sampled_tag_embeddings = list()

    for time, input_ in enumerate(inputs):
        # Bing: introduce output label embeddings as additional input
        # if feed_previous (during testing):
        #     Use loop_function
        # if NOT feed_previous (during training):
        #     Use true target embedding
        if time == 0:
            current_intent_embedding = zero_intent_embedding
            current_tag_embedding = zero_tag_embedding

        if time > 0: variable_scope.get_variable_scope().reuse_variables()

        # here we introduce a max(0, t-4)/sequence_length intent weight
        thres = zero_intent_thres
        if time <= thres:
            intent_contribution = math_ops.to_float(0)
        else:
            intent_contribution = tf.div(math_ops.to_float(time - thres),
                                         math_ops.to_float(sequence_length))


#      intent_contribution = math_ops.to_float(1)

        x = rnn_cell._linear([
            tf.transpose(
                tf.transpose(current_intent_embedding) * intent_contribution),
            current_tag_embedding, input_
        ], word_emb_size, True)
        call_cell = lambda: cell(x, state)

        # pylint: enable=cell-var-from-loop
        if sequence_length is not None:
            (output_fw,
             state) = rnn._rnn_step(time, sequence_length, min_sequence_length,
                                    max_sequence_length, zero_output, state,
                                    call_cell, cell.state_size)
        else:
            (output_fw, state) = call_cell()

        encoder_outputs.append(output_fw)

        if use_predicted_output:
            intent_logit, current_intent_embedding = intent_loop_function(
                output_fw, time)
            tagging_logit, current_tag_embedding = tagging_loop_function(
                output_fw, time)
        else:
            if train_with_true_label is True:
                intent_logit = multilayer_perceptron_with_initialized_W(
                    output_fw,
                    target1_output_projection,
                    forward_only=use_predicted_output)
                tagging_logit = multilayer_perceptron_with_initialized_W(
                    output_fw,
                    target2_output_projection,
                    forward_only=use_predicted_output)
                current_intent_embedding = intent_target_embeddings[time]
                current_tag_embedding = tag_target_embeddings[time]
            else:
                intent_logit, current_intent_embedding = intent_loop_function(
                    output_fw, time)
                tagging_logit, current_tag_embedding = tagging_loop_function(
                    output_fw, time)
            # prev_symbols.append(prev_symbol)
        if time == 0:
            current_intent_embedding = zero_intent_embedding
            current_tag_embedding = zero_tag_embedding
        sampled_intent_embeddings.append(current_intent_embedding)
        sampled_tag_embeddings.append(current_tag_embedding)

        intent_logits.append(intent_logit)
        tagging_logits.append(tagging_logit)

    return encoder_outputs, state, intent_logits, tagging_logits, sampled_intent_embeddings, sampled_tag_embeddings
Example #13
def attention_decoder(decoder_inputs, sequence_length, initial_state, attention_matrix, cell,
                      output_size=None, loop_function=None,
                      dtype=dtypes.float32, scope=None,
                      initial_state_attention=False):
    if not decoder_inputs:
        raise ValueError("Must provide at least 1 input to attention decoder.")
    if not attention_matrix.get_shape()[1:].is_fully_defined():
        raise ValueError("Shape of attention matrix must be known: %s" % attention_matrix.get_shape())
    if output_size is None:
        output_size = cell.output_size

    with variable_scope.variable_scope(scope or "attention_decoder"):
        #batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
        # Temporarily avoid EmbeddingWrapper and seq2seq badness
        # TODO(lukaszkaiser): remove EmbeddingWrapper
        if decoder_inputs[0].get_shape().ndims != 1:
            (fixed_batch_size, input_size) = decoder_inputs[0].get_shape().with_rank(2)
            if input_size.value is None:
                raise ValueError(
                    "Input size (second dimension of inputs[0]) must be accessible via "
                    "shape inference, but saw value None.")
        else:
            fixed_batch_size = decoder_inputs[0].get_shape().with_rank_at_least(1)[0]

        if fixed_batch_size.value:
            batch_size = fixed_batch_size.value
        else:
            batch_size = array_ops.shape(decoder_inputs[0])[0]

        if sequence_length is not None:
            sequence_length = math_ops.to_int32(sequence_length)
            zero_output = array_ops.zeros(tf.stack([batch_size, cell.output_size]), decoder_inputs[0].dtype)
            zero_output.set_shape(tensor_shape.TensorShape([fixed_batch_size.value, cell.output_size]))
            min_sequence_length = math_ops.reduce_min(sequence_length)
            max_sequence_length = math_ops.reduce_max(sequence_length)

        # ATTENTION COMPUTATION
        
        attn_size = attention_matrix.get_shape()[-1].value
        batch_attn_size = tf.stack([batch_size, attn_size])

        def _attention_dot(query, states):
            """Put attention masks on hidden using hidden_features and query."""
            attn_length = states.get_shape()[1].value

            hidden = array_ops.reshape(states, [-1, attn_length, 1, attn_size])
            y = _linear(query, attn_size, True)

            # dot product to produce the attention over incoming states
            s = tf.reduce_sum(tf.multiply(states, tf.expand_dims(y, 1)), -1)
            a = nn_ops.softmax(s)

            # Now calculate the attention-weighted vector d.
            d = math_ops.reduce_sum(array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
            d = array_ops.reshape(d, [-1, attn_size])
            return d

        def _attention_concat(query, states):
            """Put attention masks on hidden using hidden_features and query."""
            v = variable_scope.get_variable("AttnV", [attn_size])
            k = variable_scope.get_variable("AttnW", [1, 1, attn_size, attn_size])

            # attn is v^T * tanh(W1*h_t + U*q)
            
            # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
            attn_length = states.get_shape()[1].value
            hidden = array_ops.reshape(states, [-1, attn_length, 1, attn_size])
            hidden_features = nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")

            y = _linear(query, attn_size, True)
            y = array_ops.reshape(y, [-1, 1, 1, attn_size])
            # Attention mask is a softmax of v^T * tanh(...).
            s = math_ops.reduce_sum(v * math_ops.tanh(hidden_features + y), [2, 3])
            a = nn_ops.softmax(s)
            # Now calculate the attention-weighted vector d.
            d = math_ops.reduce_sum(array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
            d = array_ops.reshape(d, [-1, attn_size])
            return d

        _attention = _attention_dot

        def attention(query):
            if nest.is_sequence(query):
                query_list = nest.flatten(query)
                for q in query_list:
                    ndims = q.get_shape().ndims
                    if ndims:
                        assert ndims == 2
                query = array_ops.concat(query_list, 1)

            outer_states = tf.unstack(attention_matrix, axis=1)

            inner_states = []
            for i, states in enumerate(outer_states):
                with variable_scope.variable_scope("Attention_outer", reuse=i>0):
                    inner_states.append(_attention(query, states))

            with variable_scope.variable_scope("Attention_inner"):
                return _attention(query, tf.stack(inner_states, 1))

        state = cell.zero_state(batch_size, dtype) if initial_state is None else initial_state
        outputs = []
        prev = None

        attns = array_ops.zeros(batch_attn_size, dtype=dtype)
        attns.set_shape([None, attn_size])
        
        if initial_state_attention:
            attns = attention(initial_state)
        for i, inp in enumerate(decoder_inputs):
            if i > 0:
                variable_scope.get_variable_scope().reuse_variables()
            # If loop_function is set, we use it instead of decoder_inputs.
            if loop_function is not None and prev is not None:
                with variable_scope.variable_scope("loop_function", reuse=True):
                    inp = loop_function(prev, i)
            # Merge input and previous attentions into one vector of the right size.
            input_size = inp.get_shape().with_rank(2)[1]
            if input_size.value is None:
                raise ValueError("Could not infer input size from input: %s" % inp.name)
            x = _linear([inp] + [attns], input_size, True)

            if sequence_length is not None:
                call_cell = lambda: cell(x, state)
                cell_output, state = _rnn_step(
                    i, sequence_length, min_sequence_length, max_sequence_length,
                    zero_output, state, call_cell, cell.state_size)
            else:
                cell_output, state = cell(x, state)


            # Run the attention mechanism.
            if i == 0 and initial_state_attention:
                with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True):
                    attns = attention(state)
            else:
                attns = attention(state)

            with variable_scope.variable_scope("AttnOutputProjection"):
                output = _linear([cell_output] + [attns], output_size, True)
            if loop_function is not None:
                prev = output
            outputs.append(output)

    return outputs, state
Example #14
def custom_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    """Creates a recurrent neural network specified by RNNCell "cell".

  The simplest form of RNN network generated is:
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)

  However, a few other options are available:

  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.

  The dynamic calculation performed is, at time t for batch row b,
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))

  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a tensor of shape
      [batch_size, cell.input_size].
    initial_state: (optional) An initial state for the RNN.  This must be
      a tensor of appropriate type and shape [batch_size x cell.state_size].
    dtype: (optional) The data type for the initial state.  Required if
      initial_state is not provided.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) size [batch_size].  Values in [0, T).
    scope: VariableScope for the created subgraph; defaults to "RNN".

  Returns:
    A pair (outputs, state) where:
      outputs is a length T list of outputs (one for each input)
      state is the final state

  Raises:
    TypeError: If "cell" is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
  """

    if not isinstance(cell, tf.nn.rnn_cell.RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    if not isinstance(inputs, list):
        raise TypeError("inputs must be a list")
    if not inputs:
        raise ValueError("inputs must not be empty")

    outputs = []
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with tf.variable_scope(scope or "RNN") as varscope:
        fixed_batch_size = inputs[0].get_shape().with_rank_at_least(1)[0]
        if fixed_batch_size.value:
            batch_size = fixed_batch_size.value
        else:
            batch_size = array_ops.shape(inputs[0])[0]
        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError(
                    "If no initial_state is provided, dtype must be.")
            state = cell.zero_state(batch_size, dtype)

        if sequence_length is not None:
            sequence_length = tf.to_int32(sequence_length)

        if sequence_length is not None:  # Prepare variables
            zero_output = tf.zeros(tf.pack([batch_size, cell.output_size]),
                                   inputs[0].dtype)
            zero_output.set_shape(
                tf.TensorShape([fixed_batch_size.value, cell.output_size]))
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        c1h1 = []

        for time, input_ in enumerate(inputs):
            if time > 0: tf.get_variable_scope().reuse_variables()
            # pylint: disable=cell-var-from-loop
            call_cell = lambda: cell(input_, state)
            # pylint: enable=cell-var-from-loop
            if sequence_length is not None:
                (output, state) = _rnn_step(time, sequence_length,
                                            min_sequence_length,
                                            max_sequence_length, zero_output,
                                            state, call_cell)
            else:
                (output, state) = call_cell()
            if time == 0:
                c1h1 = state
            outputs.append(output)

        return (outputs, state, c1h1)
Example #15
def true_bptt_rnn(cell,
                  inputs,
                  initial_state=None,
                  dtype=None,
                  sequence_length=None,
                  scope=None,
                  state_index=1):  # Adapted From Tensorflow
    """
    Creates a recurrent neural network specified by RNNCell `cell`.
    The simplest form of RNN network generated is:

    .. code:: python

        state = cell.zero_state(...)
        outputs = []
        for input_ in inputs:
            output, state = cell(input_, state)
            outputs.append(output)
        return (outputs, state)

    However, a few other options are available:
    An initial state can be provided.
    If the sequence_length vector is provided, dynamic calculation is performed.
    This method of calculation does not compute the RNN steps past the maximum
    sequence length of the minibatch (thus saving computational time),
    and properly propagates the state at an example's sequence length
    to the final state output.
    The dynamic calculation performed is, at time t for batch row b,

    .. code ::

        (output, state)(b, t) = (t >= sequence_length(b)) ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) : cell(input(b, t), state(b, t - 1))

    :param cell: An instance of RNNCell.
    :param inputs: A length T list of inputs, each a tensor of shape
                   [batch_size, input_size].
    :param initial_state: (optional) An initial state for the RNN.
        If `cell.state_size` is an integer, this must be
        a tensor of appropriate type and shape `[batch_size x cell.state_size]`.
        If `cell.state_size` is a tuple, this should be a tuple of
        tensors having shapes `[batch_size, s] for s in cell.state_size`.
    :param dtype: (optional) The data type for the initial state.  Required if
        initial_state is not provided.
    :param sequence_length: Specifies the length of each sequence in inputs.
        An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
    :param scope: VariableScope for the created subgraph; defaults to "RNN".
    :param state_index: (int) If -1, the final state is returned; if 1, the state after the
                        first RNN step is returned. For any other value, all states are returned.
    :return: A pair (outputs, state) where:

        - outputs is a length T list of outputs (one for each input)
        - state is the final state or a length T list of cell states
    :raise: TypeError: If `cell` is not an instance of RNNCell.
      ValueError: If `inputs` is `None` or an empty list, or if the input depth
        (column size) cannot be inferred from inputs via shape inference.
    """

    if not isinstance(cell, rnn_cell.RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    if not isinstance(inputs, list):
        raise TypeError("inputs must be a list")
    if not inputs:
        raise ValueError("inputs must not be empty")

    outputs = []
    states = []
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "RNN") as varscope:
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        # Temporarily avoid EmbeddingWrapper and seq2seq badness
        if inputs[0].get_shape().ndims != 1:
            input_shape = inputs[0].get_shape().with_rank_at_least(2)
            input_shape[1:].assert_is_fully_defined()
            (fixed_batch_size, input_size) = input_shape[0], input_shape[1:]
            if input_size[0].value is None:
                raise ValueError(
                    "Input size (second dimension of inputs[0]) must be accessible via "
                    "shape inference, but saw value None.")
        else:
            fixed_batch_size = inputs[0].get_shape().with_rank_at_least(1)[0]

        if fixed_batch_size.value:
            batch_size = fixed_batch_size.value
        else:
            batch_size = array_ops.shape(inputs[0])[0]
        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError("If no initial_state is provided, "
                                 "dtype must be specified")
            state = cell.zero_state(batch_size, dtype)

        if sequence_length is not None:  # Prepare variables
            sequence_length = math_ops.to_int32(sequence_length)
            # convert int to TensorShape if necessary
            output_size = _state_size_with_prefix(cell.output_size,
                                                  prefix=[batch_size])
            zero_output = array_ops.zeros(array_ops.pack(output_size),
                                          inputs[0].dtype)
            zero_output_shape = _state_size_with_prefix(
                cell.output_size, prefix=[fixed_batch_size.value])
            zero_output.set_shape(tensor_shape.TensorShape(zero_output_shape))
            min_sequence_length = math_ops.reduce_min(sequence_length)
            max_sequence_length = math_ops.reduce_max(sequence_length)

        for time, input_ in enumerate(inputs):
            if time > 0:
                varscope.reuse_variables()
            # pylint: disable=cell-var-from-loop
            call_cell = lambda: cell(input_, state)
            # pylint: enable=cell-var-from-loop
            if sequence_length is not None:
                (output, state) = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=cell.state_size)
            else:
                (output, state) = call_cell()
            states.append(state)
            outputs.append(output)
        if state_index == 1:
            next_state = states[1]
        elif state_index == -1:
            next_state = states[-1]
        else:
            next_state = states
    return outputs, next_state
Example #16
def _rnn(cell,
         inputs,
         initial_state=None,
         dtype=None,
         sequence_length=None,
         scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  The simplest form of RNN network generated is:

  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```
  However, a few other options are available:

  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.

  The dynamic calculation performed is, at time `t` for batch row `b`,

  ```python
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))
  ```

  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`, or a nested tuple of such elements.
    initial_state: (optional) An initial state for the RNN.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    - outputs is a length T list of outputs (one for each input), or a nested
      tuple of such elements.
    - state is the final state

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the input depth
      (column size) cannot be inferred from inputs via shape inference.
  """

  if not _like_rnncell(cell):
    raise TypeError("cell must be an instance of RNNCell")
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  outputs = []
  states  = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_sequence(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().ndims != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = input_shape[0], input_shape[1:]
        fixed_batch_size.merge_with(batch_size)
        for i, size in enumerate(input_size):
          if size.value is None:
            raise ValueError(
                "Input size (dimension %d of inputs) must be accessible via "
                "shape inference, but saw value None." % i)
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, "
                         "dtype must be specified")
      state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().ndims not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size")

      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _concat(batch_size, output_size)
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        shape = _concat(fixed_batch_size.value, output_size, static=True)
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(
          structure=output_size, flat_sequence=flat_zero_output)

      sequence_length = math_ops.to_int32(sequence_length)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    for time, input_ in enumerate(inputs):
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()

      outputs.append(output)
      states.append(state)
      
    return (outputs, states)
Example #17
    def _call(self, inputs):
        # inputs = self._preprocess_input(inputs)

        actions = []
        outputs = []

        # Create a new scope in which the caching device is either
        # determined by the parent scope, or is set to place the cached
        # Variable using the same placement as for the rest of the RNN.
        with tf.variable_scope(self.name or "AttentionRNN") as varscope:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

            # Obtain the first sequence of the input
            first_input = inputs
            while nest.is_sequence(first_input):
                first_input = first_input[0]

            # Temporarily avoid EmbeddingWrapper and seq2seq badness
            # TODO(lukaszkaiser): remove EmbeddingWrapper
            if first_input.get_shape().ndims != 1:

                input_shape = first_input.get_shape().with_rank_at_least(2)
                fixed_batch_size = input_shape[0]

                flat_inputs = nest.flatten(inputs)
                for flat_input in flat_inputs:
                    input_shape = flat_input.get_shape().with_rank_at_least(2)
                    batch_size, input_size = input_shape[0], input_shape[1:]
                    fixed_batch_size.merge_with(batch_size)
                    for i, size in enumerate(input_size):
                        if size.value is None:
                            raise ValueError(
                                "Input size (dimension %d of inputs) must be accessible via "
                                "shape inference, but saw value None." % i)
            else:
                fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

            if fixed_batch_size.value:
                batch_size = fixed_batch_size.value
            else:
                batch_size = tf.shape(inputs)[0]
            if self._initial_state is not None:
                state = self._initial_state
            else:
                state = self._cell.zero_state(batch_size, inputs.dtype)

            def _create_zero_output(output_size_):
                # convert int to TensorShape if necessary
                size_ = [batch_size] + tolist(output_size_)
                output_ = tf.zeros(tf.pack(size_), inputs.dtype)
                shape = [fixed_batch_size.value] + tolist(output_size_)
                output_.set_shape(tf.TensorShape(shape))
                return output_

            flat_output_size = nest.flatten(self._output_size)
            flat_zero_output = tuple(
                _create_zero_output(size) for size in flat_output_size)
            zero_output = nest.pack_sequence_as(structure=self._output_size,
                                                flat_sequence=flat_zero_output)

            inputs_ = tolist(inputs)

            sequence_length = self._sequence_length
            if sequence_length is None or len(inputs_) > 1:
                sequence_length = len(inputs_)

            if len(inputs_) == 1:
                max_val = sequence_length
                if isinstance(sequence_length, (list, tuple)):
                    max_val = max(sequence_length)
                inputs_ = inputs_ * max_val

            sequence_length = tf.to_int32(sequence_length)
            if sequence_length.get_shape().ndims == 0:
                sequence_length = tf.expand_dims(sequence_length, 0)
                sequence_length = tf.tile(sequence_length, [batch_size])

            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

            # sample an initial starting action by forwarding zeros
            # through the action
            output = tf.zeros((batch_size, self._output_size))

            for step, in_ in enumerate(inputs_):
                if step > 0:
                    varscope.reuse_variables()

                action = self._action_model(output)
                actions.append(action)

                input_ = self._input_model([action, in_, action])

                call_cell = lambda: self._cell(input_, state)
                (output, state) = _rnn_step(
                    time=step,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=self._cell.state_size)

                outputs.append(output)

            outputs_ = tf.pack(outputs)
            axes = [1, 0] + list(range(2, len(outputs_.get_shape())))  # (batch_size, sequence_length, input0, ...)
            outputs_ = tf.transpose(outputs_, axes)

            if self._return_sequences:
                return outputs_
            return outputs[-1]