Example #1
 def _build_rnn_inputs(self, word_embedder, time_major):
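     '''
     Embed the input tokens and transpose them to time-major layout if needed.
     '''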
     inputs = word_embedder.embed(self.inputs, zero_pad=True)
     self.word_embeddings = inputs
     if not time_major:
         inputs = transpose_first_two_dims(inputs)  # (seq_len, batch_size, input_size)
     return inputs
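The helper transpose_first_two_dims is called in every example on this page but not defined here. A minimal sketch of what it presumably does, assuming the tensor rank is known statically: swap the batch and time axes and leave any trailing dimensions alone.

import tensorflow as tf

def transpose_first_two_dims(batch_input):
    # Swap the first two axes, e.g. (batch_size, seq_len, ...) <-> (seq_len, batch_size, ...).
    rank = batch_input.get_shape().ndims
    return tf.transpose(batch_input, perm=[1, 0] + list(range(2, rank)))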
Example #2
 def _build_output(self, output_dict):
     '''
     Take RNN outputs and produce logits over the vocab and the attentions.
     '''
     logits = super(CopyGraphDecoder, self)._build_output(output_dict)  # (batch_size, seq_len, num_symbols)
     attn_scores = transpose_first_two_dims(output_dict['attn_scores'])  # (batch_size, seq_len, num_nodes)
     return tf.concat([logits, attn_scores], axis=2)
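For orientation, a toy shape check (all sizes invented) of the concatenation above: the vocab logits and the node attention scores share the batch and time axes, so the combined last axis scores every vocab symbol followed by every copyable graph node.

import tensorflow as tf

batch_size, seq_len, num_symbols, num_nodes = 2, 5, 100, 7
logits = tf.zeros([batch_size, seq_len, num_symbols])
attn_scores = tf.zeros([batch_size, seq_len, num_nodes])
combined = tf.concat([logits, attn_scores], axis=2)
print(combined.get_shape())  # (2, 5, 107): num_symbols vocab scores, then num_nodes copy scores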
Example #3
 def _build_output(self, output_dict):
     '''
     Take RNN outputs and produce logits over the vocab.
     '''
     outputs = output_dict['outputs']
     outputs = transpose_first_two_dims(outputs)  # (batch_size, seq_len, output_size)
     logits = batch_linear(outputs, self.num_symbols, True)
     #logits = BasicDecoder.penalize_repetition(logits)
     return logits
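batch_linear is also not shown in these examples. The call above suggests a signature batch_linear(inputs, output_size, bias); a minimal sketch under that assumption, projecting the last dimension of a (batch_size, seq_len, output_size) tensor with one weight matrix shared across all time steps:

import tensorflow as tf

def batch_linear(inputs, output_size, bias):
    # Assumed behaviour: a single linear layer applied at every (batch, step) position.
    input_size = inputs.get_shape().as_list()[-1]
    dyn_shape = tf.shape(inputs)
    w = tf.get_variable('batch_linear_w', [input_size, output_size])
    flat = tf.reshape(inputs, [-1, input_size])  # (batch_size * seq_len, input_size)
    out = tf.matmul(flat, w)
    if bias:
        b = tf.get_variable('batch_linear_b', [output_size],
                            initializer=tf.zeros_initializer())
        out = out + b
    # Restore the leading (batch_size, seq_len) axes.
    return tf.reshape(out, tf.stack([dyn_shape[0], dyn_shape[1], output_size]))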
Example #4
    def _build_rnn_inputs(self, word_embedder, time_major):
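        '''
        Embed the inputs via the parent class and build per-step entity checklists
        marking which graph nodes have been mentioned so far.
        '''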
        inputs = super(GraphDecoder, self)._build_rnn_inputs(word_embedder, time_major)

        checklists = tf.cumsum(tf.one_hot(self.entities, self.num_nodes, on_value=1, off_value=0), axis=1) + self.init_checklists
        # cumsum can cause >1 indicator
        checklists = tf.cast(tf.greater(checklists, 0), tf.float32)
        self.output_dict['checklists'] = checklists

        checklists = transpose_first_two_dims(checklists)  # (seq_len, batch_size, num_nodes)
        return inputs, checklists
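A toy run of the checklist construction (values invented, and self.init_checklists assumed to be all zeros): one_hot marks the node mentioned at each decoding step, cumsum accumulates the mentions, and the greater/cast pair clips repeated mentions back to a {0, 1} indicator.

import tensorflow as tf

# One example in the batch, 3 decoding steps, 4 graph nodes;
# the node ids mentioned at each step are 2, 0, 2 (node 2 is mentioned twice).
entities = tf.constant([[2, 0, 2]])
num_nodes = 4
mentions = tf.one_hot(entities, num_nodes, on_value=1, off_value=0)  # (1, 3, 4)
cumulative = tf.cumsum(mentions, axis=1)                             # node 2 reaches 2 at the last step
checklists = tf.cast(tf.greater(cumulative, 0), tf.float32)

with tf.Session() as sess:
    print(sess.run(checklists))
    # [[[0. 0. 1. 0.]
    #   [1. 0. 1. 0.]
    #   [1. 0. 1. 0.]]]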
Example #5
 def _build_rnn_inputs(self, word_embedder, time_major):
     '''
     Concatenate word embedding with entity/node embedding.
     '''
     word_embeddings = word_embedder.embed(self.inputs, zero_pad=True)
     self.word_embeddings = word_embeddings
     if self.node_embed_in_rnn_inputs:
         entity_embeddings = self._get_node_embedding(self.context[0], self.entities)
         inputs = tf.concat([word_embeddings, entity_embeddings], axis=2)
     else:
         inputs = word_embeddings
     if not time_major:
         inputs = transpose_first_two_dims(inputs)  # (seq_len, batch_size, input_size)
     return inputs
Example #6
 def _get_final_state(self, states):
     '''
     Return the final non-pad state from tf.scan outputs.
     '''
     with tf.name_scope(type(self).__name__+'/get_final_state'):
         flat_states = nest.flatten(states)
         flat_last_states = []
         for state in flat_states:
             state = transpose_first_two_dims(state)  # (batch_size, seq_len, state_size)
             # NOTE: when state has dim=4, it's the context which does not change in a seq; just take the last one.
             if len(state.get_shape()) == 4:
                 last_state = state[:, -1, :, :]
             else:
                 last_state = tf.squeeze(batch_embedding_lookup(state, tf.reshape(self.last_inds, [-1, 1])), [1])
             flat_last_states.append(last_state)
         last_states = nest.pack_sequence_as(states, flat_last_states)
     return last_states
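_get_final_state relies on batch_embedding_lookup to pick, for each example, the state at its last non-pad time step (self.last_inds reshaped to (batch_size, 1)), after which tf.squeeze drops the singleton step axis. A minimal sketch of such a per-example gather, assuming the (batch_size, seq_len, state_size) layout the comments describe:

import tensorflow as tf

def batch_embedding_lookup(embeddings, indices):
    # embeddings: (batch_size, seq_len, state_size)
    # indices:    (batch_size, k) positions into each example's time axis
    # returns:    (batch_size, k, state_size)
    batch_size = tf.shape(embeddings)[0]
    seq_len = tf.shape(embeddings)[1]
    state_size = tf.shape(embeddings)[2]
    # Flatten the batch and time axes, then offset each row's indices so they
    # point into the right example's block of the flattened tensor.
    flat = tf.reshape(embeddings, tf.stack([batch_size * seq_len, state_size]))
    offsets = tf.expand_dims(tf.range(batch_size) * seq_len, 1)
    return tf.gather(flat, indices + offsets)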