Example #1
class AttentionDecoder(RNNDecoder):
    """An RNN Decoder that uses attention over an input sequence.

  Args:
    cell: An instance of `tf.contrib.rnn.RNNCell`
    helper: An instance of `tf.contrib.seq2seq.Helper` to assist decoding
    initial_state: A tensor or tuple of tensors used as the initial cell
      state.
    vocab_size: Output vocabulary size, i.e. number of units
      in the softmax layer
    attention_keys: The sequence used to calculate attention scores.
      A tensor of shape `[B, T, ...]`.
    attention_values: The sequence to attend over.
      A tensor of shape `[B, T, input_dim]`.
    attention_values_length: Sequence length of the attention values.
      An int32 Tensor of shape `[B]`.
    attention_fn: The attention function to use. This function maps from
      `(state, inputs)` to `(attention_scores, attention_context)`.
      For an example, see `seq2seq.decoder.attention.AttentionLayer`.
    reverse_scores_lengths: Optional, an array of sequence lengths. If set,
      the attention scores in the output are reversed. This is used when
      a reversed source sequence is fed as input but you want the scores
      returned in non-reversed order.
  """
    def __init__(self,
                 params,
                 mode,
                 vocab_size,
                 attention_keys,
                 attention_values,
                 attention_values_length,
                 attention_fn,
                 reverse_scores_lengths=None,
                 name="attention_decoder"):
        super(AttentionDecoder, self).__init__(params, mode, name)
        self.vocab_size = vocab_size
        self.attention_keys = attention_keys
        self.attention_values = attention_values
        self.attention_values_length = attention_values_length
        self.attention_fn = attention_fn
        self.reverse_scores_lengths = reverse_scores_lengths

    @property
    def output_size(self):
        return AttentionDecoderOutput(
            logits=self.vocab_size,
            predicted_ids=tf.TensorShape([]),
            cell_output=self.cell.output_size,
            attention_scores=tf.shape(self.attention_values)[1:-1],
            attention_context=self.attention_values.get_shape()[-1])

    @property
    def output_dtype(self):
        return AttentionDecoderOutput(logits=tf.float32,
                                      predicted_ids=tf.int32,
                                      cell_output=tf.float32,
                                      attention_scores=tf.float32,
                                      attention_context=tf.float32)

    def initialize(self, name=None):
        finished, first_inputs = self.helper.initialize()

        # Concat empty attention context
        attention_context = tf.zeros([
            tf.shape(first_inputs)[0],
            self.attention_values.get_shape().as_list()[-1]
        ])
        first_inputs = tf.concat([first_inputs, attention_context],
                                 1,
                                 name="first_inputs")
        return finished, first_inputs, self.initial_state

    def compute_output(self, cell_output):
        """Computes the decoder outputs."""

        # Compute attention
        att_scores, attention_context = self.attention_fn(
            query=cell_output,
            keys=self.attention_keys,
            values=self.attention_values,
            values_length=self.attention_values_length)

        # TODO: Make this a parameter: We may or may not want this.
        # Transform attention context.
        # This makes the softmax smaller and allows us to synthesize information
        # between decoder state and attention context
        # see https://arxiv.org/abs/1508.04025v5
        softmax_input = tf.contrib.layers.fully_connected(
            inputs=tf.concat([cell_output, attention_context], 1),
            num_outputs=self.cell.output_size,
            activation_fn=tf.nn.tanh,
            scope="attention_mix")

        # Softmax computation
        logits = tf.contrib.layers.fully_connected(inputs=softmax_input,
                                                   num_outputs=self.vocab_size,
                                                   activation_fn=None,
                                                   scope="logits")

        return softmax_input, logits, att_scores, attention_context

    def _setup(self, initial_state, helper):
        self.initial_state = initial_state

        def att_next_inputs(time, outputs, state, sample_ids, name=None):
            """Wraps the original decoder helper function to append the attention
      context.
      """
            finished, next_inputs, next_state = helper.next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name)
            next_inputs = tf.concat([next_inputs, outputs.attention_context],
                                    1)
            return (finished, next_inputs, next_state)

        self.helper = CustomHelper(initialize_fn=helper.initialize,
                                   sample_fn=helper.sample,
                                   next_inputs_fn=att_next_inputs)

    def step(self, time_, inputs, state, name=None):
        cell_output, cell_state = self.cell(inputs, state)
        cell_output_new, logits, attention_scores, attention_context = \
          self.compute_output(cell_output)

        if self.reverse_scores_lengths is not None:
            attention_scores = tf.reverse_sequence(
                input=attention_scores,
                seq_lengths=self.reverse_scores_lengths,
                seq_dim=1,
                batch_dim=0)

        sample_ids = self.helper.sample(time=time_,
                                        outputs=logits,
                                        state=cell_state)

        outputs = AttentionDecoderOutput(logits=logits,
                                         predicted_ids=sample_ids,
                                         cell_output=cell_output_new,
                                         attention_scores=attention_scores,
                                         attention_context=attention_context)

        finished, next_inputs, next_state = self.helper.next_inputs(
            time=time_,
            outputs=outputs,
            state=cell_state,
            sample_ids=sample_ids)

        return (outputs, next_state, next_inputs, finished)
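
For reference, here is a minimal sketch of an attention function matching the interface this decoder expects, mapping `(query, keys, values, values_length)` to `(attention_scores, attention_context)`. It is an illustrative dot-product variant written under the assumption that the keys share the query's dimensionality; it is not the library's `AttentionLayer`.

import tensorflow as tf


def dot_product_attention_fn(query, keys, values, values_length):
    """Illustrative attention_fn: query [B, H], keys [B, T, H],
    values [B, T, D], values_length [B] -> (scores [B, T], context [B, D])."""
    # Unnormalized scores: dot product between the query and every key.
    scores = tf.reduce_sum(keys * tf.expand_dims(query, 1), axis=2)  # [B, T]
    # Mask padded positions before normalizing.
    mask = tf.sequence_mask(values_length,
                            maxlen=tf.shape(keys)[1],
                            dtype=tf.float32)
    scores = tf.nn.softmax(scores * mask + (1.0 - mask) * -1e9)
    # Context vector: attention-weighted sum of the values.
    context = tf.reduce_sum(tf.expand_dims(scores, 2) * values, axis=1)
    return scores, context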
Example #2
class CopyGenDecoder(RNNDecoder):
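    """An RNN Decoder with attention that additionally computes a per-step
    generation probability (`pgens`, see `cal_gen_probability`) alongside
    the logits."""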
    def __init__(self,
                 params,
                 mode,
                 vocab_size,
                 attention_keys,
                 attention_values,
                 attention_values_length,
                 attention_fn,
                 source_embedding=None,
                 reverse_scores_lengths=None,
                 name="CopyGenDecoder"):
        super(CopyGenDecoder, self).__init__(params, mode, name)
        self.vocab_size = vocab_size
        self.source_embedding = source_embedding
        self.attention_keys = attention_keys
        self.attention_values = attention_values
        self.attention_values_length = attention_values_length
        self.attention_fn = attention_fn
        self.reverse_scores_lengths = reverse_scores_lengths

    @property
    def output_size(self):
        return CopyGenDecoderOutput(
            logits=self.vocab_size,
            predicted_ids=tf.TensorShape([]),
            cell_output=self.cell.output_size,
            attention_scores=tf.shape(self.attention_values)[1:-1],
            attention_context=self.attention_values.get_shape()[-1],
            pgens=tf.TensorShape([1]))

    @property
    def output_dtype(self):
        return CopyGenDecoderOutput(logits=tf.float32,
                                    predicted_ids=tf.int32,
                                    cell_output=tf.float32,
                                    attention_scores=tf.float32,
                                    attention_context=tf.float32,
                                    pgens=tf.float32)

    def initialize(self, name=None):

        finished, first_inputs = self.helper.initialize()

        # Concat empty attention context
        attention_context = tf.zeros([
            tf.shape(first_inputs)[0],
            self.attention_values.get_shape().as_list()[-1]
        ])
        first_inputs = tf.concat([first_inputs, attention_context], 1)

        return finished, first_inputs, self.initial_state

    def cal_gen_probability(self, out_features):
        # Probability of generation.
        pgen = tf.contrib.layers.fully_connected(
            inputs=out_features,
            num_outputs=1,
            activation_fn=tf.nn.tanh,
        )
        return pgen

    def compute_output(self, cell_output):
        """Computes the decoder outputs."""

        # Compute attention
        att_scores, attention_context = self.attention_fn(
            query=cell_output,
            keys=self.attention_keys,
            values=self.attention_values,
            values_length=self.attention_values_length)

        # TODO: Make this a parameter: We may or may not want this.
        # Transform attention context.
        # This makes the softmax smaller and allows us to synthesize information
        # between decoder state and attention context
        # see https://arxiv.org/abs/1508.04025v5

        decode_out_features = tf.concat([cell_output, attention_context], 1)
        softmax_input = tf.contrib.layers.fully_connected(
            inputs=decode_out_features,
            num_outputs=self.cell.output_size,
            activation_fn=tf.nn.tanh,
            scope="attention_mix")

        # Softmax computation
        logits = tf.contrib.layers.fully_connected(inputs=softmax_input,
                                                   num_outputs=self.vocab_size,
                                                   activation_fn=None,
                                                   scope="logits")

        # Generation probability.
        pgens = self.cal_gen_probability(decode_out_features)

        return softmax_input, logits, att_scores, attention_context, pgens

    def _setup(self, initial_state, helper):

        self.initial_state = initial_state

        self.W_u = 0

        # wout_dim = 2 * self.params["decoder.params"]["rnn_cell"]["cell_params"]["num_units"]  # dim(context_vector + hidden states)
        # word_dim = self.source_embedding.shape[1].value
        # with tf.variable_scope("copy_gen_project_wout"):
        #   self.wout_proj = tf.get_variable("project_wout", shape=[word_dim, word_dim], dtype=tf.float32, initializer=tf.truncated_normal_initializer)
        #   self.wout = tf.tanh(tf.matmul(self.source_embedding, self.wout_proj)) #[v,d] * [d,d] = [v,d]

        def att_next_inputs(time, outputs, state, sample_ids, name=None):
            """Wraps the original decoder helper function to append the attention
      context.
      """
            finished, next_inputs, next_state = helper.next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name)
            next_inputs = tf.concat([next_inputs, outputs.attention_context],
                                    1)
            return (finished, next_inputs, next_state)

        self.helper = CustomHelper(initialize_fn=helper.initialize,
                                   sample_fn=helper.sample,
                                   next_inputs_fn=att_next_inputs)

    def step(self, time_, inputs, state, name=None):
        cell_output, cell_state = self.cell(inputs, state)
        cell_output_new, logits, attention_scores, attention_context, pgens = \
          self.compute_output(cell_output)

        if self.reverse_scores_lengths is not None:
            attention_scores = tf.reverse_sequence(
                input=attention_scores,
                seq_lengths=self.reverse_scores_lengths,
                seq_dim=1,
                batch_dim=0)

        sample_ids = self.helper.sample(time=time_,
                                        outputs=logits,
                                        state=cell_state)

        outputs = CopyGenDecoderOutput(logits=logits,
                                       predicted_ids=sample_ids,
                                       cell_output=cell_output_new,
                                       attention_scores=attention_scores,
                                       attention_context=attention_context,
                                       pgens=pgens)

        finished, next_inputs, next_state = self.helper.next_inputs(
            time=time_,
            outputs=outputs,
            state=cell_state,
            sample_ids=sample_ids)

        return (outputs, next_state, next_inputs, finished)

    def _build(self, initial_state, helper):
        if not self.initial_state:
            self._setup(initial_state, helper)

        scope = tf.get_variable_scope()
        scope.set_initializer(
            tf.random_uniform_initializer(-self.params["init_scale"],
                                          self.params["init_scale"]))

        maximum_iterations = None
        if self.mode == tf.contrib.learn.ModeKeys.INFER:
            maximum_iterations = self.params["max_decode_length"]

        outputs, final_state = dynamic_decode(
            decoder=self,
            output_time_major=True,
            impute_finished=False,
            maximum_iterations=maximum_iterations)
        return self.finalize(outputs, final_state)
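
The decoder above only emits `pgens` alongside the logits; combining them into a final output distribution presumably happens downstream. Below is a minimal sketch of the usual pointer-generator style mixing, assuming a hypothetical `source_ids` tensor holding the vocabulary ids of the source tokens and `pgens` bounded to `[0, 1]` (the layer above uses `tanh`, so a `sigmoid` would be needed for this interpretation).

import tensorflow as tf


def mix_generate_and_copy(logits, attention_scores, pgens, source_ids,
                          vocab_size):
    """Hypothetical mixing of generation and copy distributions.

    logits [B, V], attention_scores [B, T] (normalized), pgens [B, 1] in
    [0, 1], source_ids [B, T] int32, vocab_size a Python int V.
    """
    gen_dist = tf.nn.softmax(logits)                          # [B, V]
    # Scatter the attention mass onto the vocabulary ids it points at.
    batch_size = tf.shape(source_ids)[0]
    batch_idx = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, tf.shape(source_ids)[1]])         # [B, T]
    indices = tf.stack([batch_idx, source_ids], axis=2)       # [B, T, 2]
    copy_dist = tf.scatter_nd(indices, attention_scores,
                              shape=tf.stack([batch_size, vocab_size]))
    return pgens * gen_dist + (1.0 - pgens) * copy_dist       # [B, V]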
Example #3
class SchemaAttentionDecoder(RNNDecoder):
    """An RNN Decoder that uses attention over an input sequence and a schema.

  Args:
    cell: An instance of `tf.contrib.rnn.RNNCell`
    helper: An instance of `tf.contrib.seq2seq.Helper` to assist decoding
    initial_state: A tensor or tuple of tensors used as the initial cell
      state.
    vocab_size: Output vocabulary size, i.e. number of units
      in the softmax layer
    attention_keys: The sequence used to calculate attention scores.
      A tensor of shape `[B, T, ...]`.
    attention_values: The sequence to attend over.
      A tensor of shape `[B, T, input_dim]`.
    attention_values_length: Sequence length of the attention values.
      An int32 Tensor of shape `[B]`.
    attention_fn: The attention function to use. This function maps from
      `(state, inputs)` to `(attention_scores, attention_context)`.
      For an example, see `seq2seq.decoder.attention.AttentionLayer`.
    reverse_scores_lengths: Optional, an array of sequence lengths. If set,
      the attention scores in the output are reversed. This is used when
      a reversed source sequence is fed as input but you want the scores
      returned in non-reversed order.
    schema_attention_keys: The schema sequence used to calculate schema
      attention scores. A tensor of shape `[B, S, ...]`.
    schema_attention_values: The schema sequence to attend over.
      A tensor of shape `[B, S, schema_dim]`.
    schema_attention_values_length: Sequence length of the schema attention
      values. An int32 Tensor of shape `[B]`.
    schema_attention_fn: The attention function to use over the schema.
      Defaults to `attention_fn` when not provided.
  """

    # The schema attention function (schema_attention_fn) is defined in
    # models/schema_attention_seq2seq.py.
    def __init__(
            self,
            params,
            mode,
            vocab_size,
            attention_keys,
            attention_values,
            attention_values_length,
            attention_fn,
            reverse_scores_lengths=None,
            # Four extra schema-related arguments compared to AttentionDecoder:
            schema_attention_keys=None,
            schema_attention_values=None,
            schema_attention_values_length=None,
            schema_attention_fn=None,
            name="schema_attention_decoder"):
        super(SchemaAttentionDecoder, self).__init__(params, mode, name)
        self.vocab_size = vocab_size
        self.attention_keys = attention_keys
        self.attention_values = attention_values
        self.attention_values_length = attention_values_length
        self.attention_fn = attention_fn
        self.reverse_scores_lengths = reverse_scores_lengths
        self.schema_attention_keys = schema_attention_keys
        self.schema_attention_values = schema_attention_values
        self.schema_attention_values_length = schema_attention_values_length
        if schema_attention_fn:
            self.schema_attention_fn = schema_attention_fn
        else:
            self.schema_attention_fn = attention_fn

    @property
    def output_size(self):
        return SchemaAttentionDecoderOutput(
            logits=self.vocab_size,
            predicted_ids=tf.TensorShape([]),
            cell_output=self.cell.output_size,
            attention_scores=tf.shape(self.attention_values)[1:-1],
            attention_context=self.attention_values.get_shape()[-1],
            schema_attention_scores=tf.shape(
                self.schema_attention_values)[1:-1],
            schema_attention_context=(
                self.schema_attention_values.get_shape()[-1]))

    @property
    def output_dtype(self):
        return SchemaAttentionDecoderOutput(
            logits=tf.float32,
            predicted_ids=tf.int32,
            cell_output=tf.float32,
            attention_scores=tf.float32,
            attention_context=tf.float32,
            schema_attention_scores=tf.float32,
            schema_attention_context=tf.float32)

    def initialize(self, name=None):
        finished, first_inputs = self.helper.initialize()

        # Concat empty attention context
        attention_context = tf.zeros([
            tf.shape(first_inputs)[0],
            self.attention_values.get_shape().as_list()[-1]
        ])
        schema_attention_context = tf.zeros([
            tf.shape(first_inputs)[0],
            self.schema_attention_values.get_shape().as_list()[-1]
        ])
        first_inputs = tf.concat(
            [first_inputs, attention_context, schema_attention_context], 1)

        return finished, first_inputs, self.initial_state

    def compute_output(self, cell_output, calculate_softmax=True):
        """Computes the decoder outputs."""

        # Compute attention
        att_scores, attention_context = self.attention_fn(
            query=cell_output,
            keys=self.attention_keys,
            values=self.attention_values,
            values_length=self.attention_values_length)
        # Compute attention over the schema. schema_attention_fn falls back
        # to attention_fn when none was provided (see __init__).
        schema_att_scores, schema_attention_context = self.schema_attention_fn(
            query=cell_output,
            keys=self.schema_attention_keys,
            values=self.schema_attention_values,
            values_length=self.schema_attention_values_length)

        softmax_input = None
        logits = None
        if calculate_softmax:
            softmax_input, logits = self._calculate_softmax(
                [cell_output, attention_context, schema_attention_context])

        return (softmax_input, logits, att_scores, attention_context,
                schema_att_scores, schema_attention_context)

    def _calculate_softmax(self, list_of_contexts):
        softmax_input = tf.contrib.layers.fully_connected(
            inputs=tf.concat(list_of_contexts, 1),
            num_outputs=self.cell.output_size,
            activation_fn=tf.nn.tanh,
            scope="attention_mix")

        # Softmax computation
        logits = tf.contrib.layers.fully_connected(inputs=softmax_input,
                                                   num_outputs=self.vocab_size,
                                                   activation_fn=None,
                                                   scope="logits")
        return softmax_input, logits

    def _setup(self, initial_state, helper):
        self.initial_state = initial_state

        def att_next_inputs(time, outputs, state, sample_ids, name=None):
            """Wraps the original decoder helper function to append the attention
      context.
      """
            finished, next_inputs, next_state = helper.next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name)
            next_inputs = tf.concat([
                next_inputs, outputs.attention_context,
                outputs.schema_attention_context
            ], 1)
            return (finished, next_inputs, next_state)

        self.helper = CustomHelper(initialize_fn=helper.initialize,
                                   sample_fn=helper.sample,
                                   next_inputs_fn=att_next_inputs)

    def step(self, time_, inputs, state, name=None):
        cell_output, cell_state = self.cell(inputs, state)
        (cell_output_new, logits, attention_scores, attention_context,
         schema_attention_scores, schema_attention_context) = \
            self.compute_output(cell_output)

        if self.reverse_scores_lengths is not None:
            attention_scores = tf.reverse_sequence(
                input=attention_scores,
                seq_lengths=self.reverse_scores_lengths,
                seq_dim=1,
                batch_dim=0)

        sample_ids = self.helper.sample(time=time_,
                                        outputs=logits,
                                        state=cell_state)

        outputs = SchemaAttentionDecoderOutput(
            logits=logits,
            predicted_ids=sample_ids,
            cell_output=cell_output_new,
            attention_scores=attention_scores,
            attention_context=attention_context,
            schema_attention_scores=schema_attention_scores,
            schema_attention_context=schema_attention_context)

        finished, next_inputs, next_state = self.helper.next_inputs(
            time=time_,
            outputs=outputs,
            state=cell_state,
            sample_ids=sample_ids)

        return (outputs, next_state, next_inputs, finished)
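
A hedged sketch of how this decoder might be instantiated, passing the encoder outputs and the schema embeddings as the two sets of attention keys/values. All tensor names below are illustrative; the real wiring lives in models/schema_attention_seq2seq.py, as the comment above the constructor notes.

# Illustrative wiring only; the tensor names are hypothetical.
decoder = SchemaAttentionDecoder(
    params=decoder_params,
    mode=tf.contrib.learn.ModeKeys.TRAIN,
    vocab_size=target_vocab_size,
    attention_keys=encoder_outputs,                 # [B, T, ...]
    attention_values=encoder_outputs,               # [B, T, input_dim]
    attention_values_length=source_lengths,         # [B]
    attention_fn=attention_layer,
    schema_attention_keys=schema_embeddings,        # [B, S, ...]
    schema_attention_values=schema_embeddings,      # [B, S, schema_dim]
    schema_attention_values_length=schema_lengths,  # [B]
    schema_attention_fn=None)  # falls back to attention_fn in __init__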
Example #4
class AttentionDecoder(RNNDecoder):
  """An RNN Decoder that uses attention over an input sequence.

  Args:
    cell: An instance of `tf.contrib.rnn.RNNCell`
    helper: An instance of `tf.contrib.seq2seq.Helper` to assist decoding
    initial_state: A tensor or tuple of tensors used as the initial cell
      state.
    vocab_size: Output vocabulary size, i.e. number of units
      in the softmax layer
    attention_keys: The sequence used to calculate attention scores.
      A tensor of shape `[B, T, ...]`.
    attention_values: The sequence to attend over.
      A tensor of shape `[B, T, input_dim]`.
    attention_values_length: Sequence length of the attention values.
      An int32 Tensor of shape `[B]`.
    attention_fn: The attention function to use. This function maps from
      `(state, inputs)` to `(attention_scores, attention_context)`.
      For an example, see `seq2seq.decoder.attention.AttentionLayer`.
    reverse_scores_lengths: Optional, an array of sequence lengths. If set,
      the attention scores in the output are reversed. This is used when
      a reversed source sequence is fed as input but you want the scores
      returned in non-reversed order.
  """

  def __init__(self,
               params,
               mode,
               vocab_size,
               attention_keys,
               attention_values,
               attention_values_length,
               attention_fn,
               reverse_scores_lengths=None,
               name="attention_decoder"):
    super(AttentionDecoder, self).__init__(params, mode, name)
    self.vocab_size = vocab_size
    self.attention_keys = attention_keys
    self.attention_values = attention_values
    self.attention_values_length = attention_values_length
    self.attention_fn = attention_fn
    self.reverse_scores_lengths = reverse_scores_lengths

  @property
  def output_size(self):
    return AttentionDecoderOutput(
        logits=self.vocab_size,
        predicted_ids=tf.TensorShape([]),
        cell_output=self.cell.output_size,
        attention_scores=tf.shape(self.attention_values)[1:-1],
        attention_context=self.attention_values.get_shape()[-1])

  @property
  def output_dtype(self):
    return AttentionDecoderOutput(
        logits=tf.float32,
        predicted_ids=tf.int32,
        cell_output=tf.float32,
        attention_scores=tf.float32,
        attention_context=tf.float32)

  def initialize(self, name=None):
    finished, first_inputs = self.helper.initialize()

    # Concat empty attention context
    attention_context = tf.zeros([
        tf.shape(first_inputs)[0],
        self.attention_values.get_shape().as_list()[-1]
    ])
    first_inputs = tf.concat([first_inputs, attention_context], 1)

    return finished, first_inputs, self.initial_state

  def compute_output(self, cell_output):
    """Computes the decoder outputs."""

    # Compute attention
    att_scores, attention_context = self.attention_fn(
        query=cell_output,
        keys=self.attention_keys,
        values=self.attention_values,
        values_length=self.attention_values_length)

    # TODO: Make this a parameter: We may or may not want this.
    # Transform attention context.
    # This makes the softmax smaller and allows us to synthesize information
    # between decoder state and attention context
    # see https://arxiv.org/abs/1508.04025v5
    softmax_input = tf.contrib.layers.fully_connected(
        inputs=tf.concat([cell_output, attention_context], 1),
        num_outputs=self.cell.output_size,
        activation_fn=tf.nn.tanh,
        scope="attention_mix")

    # Softmax computation
    logits = tf.contrib.layers.fully_connected(
        inputs=softmax_input,
        num_outputs=self.vocab_size,
        activation_fn=None,
        scope="logits")

    return softmax_input, logits, att_scores, attention_context

  def _setup(self, initial_state, helper):
    self.initial_state = initial_state

    def att_next_inputs(time, outputs, state, sample_ids, name=None):
      """Wraps the original decoder helper function to append the attention
      context.
      """
      finished, next_inputs, next_state = helper.next_inputs(
          time=time,
          outputs=outputs,
          state=state,
          sample_ids=sample_ids,
          name=name)
      next_inputs = tf.concat([next_inputs, outputs.attention_context], 1)
      return (finished, next_inputs, next_state)

    self.helper = CustomHelper(
        initialize_fn=helper.initialize,
        sample_fn=helper.sample,
        next_inputs_fn=att_next_inputs)

  def step(self, time_, inputs, state, name=None):
    cell_output, cell_state = self.cell(inputs, state)
    cell_output_new, logits, attention_scores, attention_context = \
      self.compute_output(cell_output)

    if self.reverse_scores_lengths is not None:
      attention_scores = tf.reverse_sequence(
          input=attention_scores,
          seq_lengths=self.reverse_scores_lengths,
          seq_dim=1,
          batch_dim=0)

    sample_ids = self.helper.sample(
        time=time_, outputs=logits, state=cell_state)

    outputs = AttentionDecoderOutput(
        logits=logits,
        predicted_ids=sample_ids,
        cell_output=cell_output_new,
        attention_scores=attention_scores,
        attention_context=attention_context)

    finished, next_inputs, next_state = self.helper.next_inputs(
        time=time_, outputs=outputs, state=cell_state, sample_ids=sample_ids)

    return (outputs, next_state, next_inputs, finished)
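
A hedged usage sketch for this decoder, assuming hypothetical encoder/target tensors and the TF 1.x `tf.contrib.seq2seq.TrainingHelper`, and assuming that calling the decoder dispatches to `_build(initial_state, helper)` as the override in Example #2 suggests. The decoder wraps the helper itself in `_setup` before decoding.

# Illustrative wiring only; the tensor names are hypothetical.
helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=target_embedded,            # [B, T_dec, emb_dim]
    sequence_length=target_lengths)    # [B]

decoder = AttentionDecoder(
    params=decoder_params,
    mode=tf.contrib.learn.ModeKeys.TRAIN,
    vocab_size=target_vocab_size,
    attention_keys=encoder_outputs,
    attention_values=encoder_outputs,
    attention_values_length=source_lengths,
    attention_fn=attention_layer)

# Calling the decoder runs _setup and the decode loop (initialize/step above).
decoder_output, final_state = decoder(encoder_final_state, helper)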
Example #5
class SchemaAttentionCopyingDecoder(SchemaAttentionDecoder):
    """
  The version of SchemaAttentionCopyingDecoder that uses
  F(score_n, rowembedding_n, h, c, W) to generate a score for the
  n-th field in the schema.
  """
    def __init__(self,
                 params,
                 mode,
                 vocab_size,
                 attention_keys,
                 attention_values,
                 attention_values_length,
                 attention_fn,
                 reverse_scores_lengths=None,
                 schema_attention_keys=None,
                 schema_attention_values=None,
                 schema_attention_values_length=None,
                 schema_attention_fn=None,
                 name="schema_attention_copying_decoder"):
        super(SchemaAttentionCopyingDecoder, self).__init__(
            params, mode, vocab_size, attention_keys, attention_values,
            attention_values_length, attention_fn, reverse_scores_lengths,
            schema_attention_keys, schema_attention_values,
            schema_attention_values_length, schema_attention_fn, name)
        self.schema_embs = schema_attention_values

    @property
    def output_size(self):
        return SchemaCopyingAttentionDecoderOutput(
            logits=self.vocab_size,
            predicted_ids=tf.TensorShape([]),
            cell_output=self.cell.output_size,
            attention_scores=tf.shape(self.attention_values)[1:-1],
            attention_context=self.attention_values.get_shape()[-1],
            schema_attention_scores=tf.shape(
                self.schema_attention_values)[1:-1],
            schema_attention_context=(
                self.schema_attention_values.get_shape()[-1]),
            schema_attention_copy_vals=tf.shape(
                self.schema_attention_values)[1:-1])

    @property
    def output_dtype(self):
        return SchemaCopyingAttentionDecoderOutput(
            logits=tf.float32,
            predicted_ids=tf.int32,
            cell_output=tf.float32,
            attention_scores=tf.float32,
            attention_context=tf.float32,
            schema_attention_scores=tf.float32,
            schema_attention_context=tf.float32,
            schema_attention_copy_vals=tf.float32)

    def compute_output(self, cell_output):
        (softmax_input, logits, att_scores, attention_context,
         schema_att_scores,
         schema_attention_context) = super(SchemaAttentionCopyingDecoder,
                                           self).compute_output(cell_output)
        weighted_schema_embs_size = self.cell.output_size + \
                                    self.attention_values.get_shape().as_list()[-1]
        weighted_schema_embs = tf.contrib.layers.fully_connected(
            inputs=self.schema_embs,
            num_outputs=weighted_schema_embs_size,
            activation_fn=None,
            scope="weighted_schema_embs")

        # Bilinear copy score: (W * schema_emb_n) . [cell_output; attention_context],
        # scaled by the schema attention score for field n.
        concatenated = tf.expand_dims(
            tf.concat([cell_output, attention_context], 1), axis=2)
        schema_attention_copy_vals = schema_att_scores * tf.squeeze(
            tf.matmul(weighted_schema_embs, concatenated), axis=2)

        return (softmax_input, logits, att_scores, attention_context,
                schema_att_scores, schema_attention_context,
                schema_attention_copy_vals)

    def _setup(self, initial_state, helper):
        #TODO: Take advantage of inheritance rather than copy-paste
        self.initial_state = initial_state

        def att_next_inputs(time, outputs, state, sample_ids, name=None):
            """Wraps the original decoder helper function to append the attention
      context.
      """
            finished, next_inputs, next_state = helper.next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name)
            next_inputs = tf.concat([
                next_inputs, outputs.attention_context,
                outputs.schema_attention_context
            ], 1)
            return (finished, next_inputs, next_state)

        self.helper = CustomHelper(initialize_fn=helper.initialize,
                                   sample_fn=helper.sample,
                                   next_inputs_fn=att_next_inputs)

    def step(self, time_, inputs, state, name=None):
        cell_output, cell_state = self.cell(inputs, state)
        (cell_output_new, logits, attention_scores, attention_context,
         schema_attention_scores, schema_attention_context,
         schema_attention_copy_vals) = \
            self.compute_output(cell_output)

        if self.reverse_scores_lengths is not None:
            attention_scores = tf.reverse_sequence(
                input=attention_scores,
                seq_lengths=self.reverse_scores_lengths,
                seq_dim=1,
                batch_dim=0)

        sample_ids = self.helper.sample(time=time_,
                                        outputs=logits,
                                        state=cell_state)
        outputs = SchemaCopyingAttentionDecoderOutput(
            logits=logits,
            predicted_ids=sample_ids,
            cell_output=cell_output_new,
            attention_scores=attention_scores,
            attention_context=attention_context,
            schema_attention_scores=schema_attention_scores,
            schema_attention_context=schema_attention_context,
            schema_attention_copy_vals=schema_attention_copy_vals)

        finished, next_inputs, next_state = self.helper.next_inputs(
            time=time_,
            outputs=outputs,
            state=cell_state,
            sample_ids=sample_ids)
        return (outputs, next_state, next_inputs, finished)
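
For reference, a hedged shape sketch of the bilinear copy score computed in `compute_output` above, with purely illustrative sizes (batch B=2, S=5 schema fields, field-embedding size E=16, cell output size H=32, attention-context size D=24).

import tensorflow as tf

schema_embs = tf.zeros([2, 5, 16])         # schema_attention_values [B, S, E]
cell_output = tf.zeros([2, 32])            # [B, H]
attention_context = tf.zeros([2, 24])      # [B, D]

weighted = tf.contrib.layers.fully_connected(
    inputs=schema_embs, num_outputs=32 + 24, activation_fn=None)  # [2, 5, 56]
query = tf.expand_dims(
    tf.concat([cell_output, attention_context], 1), axis=2)       # [2, 56, 1]
copy_vals = tf.squeeze(tf.matmul(weighted, query), axis=2)        # [2, 5]
# copy_vals is then scaled elementwise by the schema attention scores.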