Example #1
    def beam_search_step(self,
                         input,
                         state,
                         beam_size,
                         attention_construct_fn=None,
                         attention_keys=None,
                         attention_values=None,
                         input_text=None):
        output, state = self.cell(input, state)

        # TODO: attention decoding still needs input_text fed at every step
        # after initialization. Will this cause attention_keys and
        # attention_values to be recomputed (i.e. the encoding redone) each
        # step? There seems to be no better way to avoid it; if encoding is
        # slow, maybe feed attention_keys and attention_values each step.
        if attention_construct_fn is not None:
            output = attention_construct_fn(output, attention_keys,
                                            attention_values)

        logits = tf.nn.xw_plus_b(output, self.w, self.v)
        logprobs = tf.nn.log_softmax(logits)

        if input_text is not None:
            # restrict candidates to the tokens that appear in input_text
            logprobs = melt.gather_cols(logprobs, tf.to_int32(input_text))

        top_logprobs, top_ids = tf.nn.top_k(logprobs, beam_size)
        # Taking top_k over the full vocabulary was too slow: transferring
        # large tensors between Python and C++ costs a lot.
        #top_logprobs, top_ids = tf.nn.top_k(logprobs, self.vocab_size)

        if input_text is not None:
            # map the top_k positions (column indices into input_text)
            # back to vocabulary ids
            top_ids = tf.nn.embedding_lookup(input_text, top_ids)

        return output, state, top_logprobs, top_ids
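For context, the step function above is meant to be driven by an outer decoding loop that repeatedly expands hypotheses. A minimal sketch of such a loop follows; `decoder`, `embedding`, and the single-path (greedy) simplification are assumptions for illustration, not part of the original code:

    import tensorflow as tf

    def decode_loop_sketch(decoder, initial_input, initial_state,
                           embedding, beam_size, max_steps):
        """Hypothetical driver for beam_search_step (Example #1 signature).

        A full beam search would keep beam_size hypotheses alive and
        re-rank them by accumulated log-probability; this sketch follows
        only the single best continuation to show the calling convention.
        """
        input, state = initial_input, initial_state
        ids = []
        for _ in range(max_steps):
            output, state, top_logprobs, top_ids = decoder.beam_search_step(
                input, state, beam_size)
            best_id = top_ids[:, 0]  # highest-scoring candidate per batch row
            ids.append(best_id)
            input = tf.nn.embedding_lookup(embedding, best_id)
        return tf.stack(ids, axis=1)  # [batch, max_steps] decoded ids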
Example #2
    def beam_search_step(self,
                         input,
                         state,
                         cell,
                         beam_size,
                         attention_construct_fn=None,
                         input_text=None):
        output, state = cell(input, state)

        # Expose the per-step attention weights through graph collections,
        # presumably so they can be inspected or visualized later.
        if hasattr(state, 'alignments'):
            tf.add_to_collection('attention_alignments', state.alignments)
            tf.add_to_collection('beam_search_alignments',
                                 tf.get_collection('attention_alignments')[-1])

        # TODO: attention decoding still needs input_text fed at every step
        # after initialization. Will this cause attention_keys and
        # attention_values to be recomputed (i.e. the encoding redone) each
        # step? There seems to be no better way to avoid it; if encoding is
        # slow, maybe feed attention_keys and attention_values each step.
        if not FLAGS.decode_use_alignment:
            if FLAGS.gen_only:
                output_fn = self.output_fn
                logits = output_fn(output)
            else:
                indices = melt.batch_values_to_indices(tf.to_int32(input_text))
                batch_size = melt.get_batch_size(input)

                if FLAGS.copy_only:
                    output_fn_ = self.copy_output_fn
                else:
                    output_fn_ = self.gen_copy_output_fn
                output_fn = lambda cell_output, cell_state: output_fn_(
                    indices, batch_size, cell_output, cell_state)

                logits = output_fn(output, state)

            if FLAGS.gen_copy_switch and FLAGS.switch_after_softmax:
                # the switch was applied after softmax, so logits already
                # hold probabilities; take the log directly
                logprobs = tf.log(logits)
            else:
                logprobs = tf.nn.log_softmax(logits)

            if FLAGS.decode_copy:
                # restrict candidates to the tokens that appear in input_text
                logprobs = melt.gather_cols(logprobs, tf.to_int32(input_text))
        else:
            # score candidates directly by the attention alignments,
            # truncated to the length of input_text
            logits = state.alignments
            logits = logits[:, :tf.shape(input_text)[-1]]
            logprobs = tf.nn.log_softmax(logits)

        top_logprobs, top_ids = tf.nn.top_k(logprobs, beam_size)
        # Taking top_k over the full vocabulary was too slow: transferring
        # large tensors between Python and C++ costs a lot.
        #top_logprobs, top_ids = tf.nn.top_k(logprobs, self.vocab_size)

        if input_text is not None and FLAGS.decode_copy:
            # map the top_k positions (column indices into input_text)
            # back to vocabulary ids
            top_ids = tf.nn.embedding_lookup(input_text, top_ids)

        if hasattr(state, 'cell_state'):
            # unwrap the underlying RNN state from the attention wrapper state
            state = state.cell_state

        return output, state, top_logprobs, top_ids
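Both examples rely on melt.gather_cols to restrict the log-probabilities to the columns named by input_text. melt appears to be a project-local utility library, so its actual implementation is not shown here; a plausible per-row equivalent in plain TensorFlow (name and shape conventions are assumptions) might look like:

    import tensorflow as tf

    def gather_cols_sketch(params, indices):
        """Per-row column gather: out[b, j] = params[b, indices[b, j]].

        Assumed stand-in for melt.gather_cols; params is [batch, vocab]
        log-probs and indices is [batch, num_cols] int32 column ids
        (the token ids of input_text).
        """
        batch_size = tf.shape(params)[0]
        num_cols = tf.shape(indices)[1]
        # Pair each column id with its row id so tf.gather_nd can address
        # individual elements of the 2-D params matrix.
        row_ids = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                          [1, num_cols])
        return tf.gather_nd(params, tf.stack([row_ids, indices], axis=-1))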