def build_beam_search_graph(self,
                            beam_size,
                            batch_size,
                            max_decode_length,
                            decode_alpha=0.6):
    self.def_placeholder_and_components()
    # Encode each input segment in turn; the key/value caches ("presents")
    # are concatenated along the sequence axis (-2) so the decoder sees one
    # contiguous past covering all segments.
    past_for_decoder = None
    for i in range(0, self.input_num):
        presents = self.encoder.encode(self.inputs[i], self.input_lens[i],
                                       past_for_decoder)
        if past_for_decoder is None:
            past_for_decoder = presents
        else:
            past_for_decoder = tf.concat([past_for_decoder, presents],
                                         axis=-2)
    with tf.name_scope('beam_search'):
        # Every beam starts from the start-of-sequence token.
        init_seq = tf.fill(dims=(batch_size, 1), value=self.sos_id)
        seqs, scores = beamsearch.create_inference_graph(
            init_seqs=init_seq,
            state=past_for_decoder,
            step_fn=self.decoder.decode_one_step,
            hparams=self.hparams,
            decode_length=max_decode_length,
            batch_size=batch_size,
            beam_size=beam_size,
            decode_alpha=decode_alpha,
            eos_id=self.eos_id,
            ensemble=False,
            concat_state_dim=None)
    return seqs, scores
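
For orientation: assuming the GPT-2 reference layout, model.past_shape is
[batch, n_layer, 2, n_head, sequence, head_dim], so axis -2 in the loop above
is the sequence axis. A minimal NumPy sketch of the concatenation (all shape
numbers are illustrative assumptions, not values from this code):

import numpy as np

# Illustrative caches: [batch, n_layer, 2, n_head, seq_len, head_dim].
past_a = np.zeros((2, 12, 2, 12, 7, 64))  # cache after encoding a 7-token segment
past_b = np.zeros((2, 12, 2, 12, 5, 64))  # cache after encoding a 5-token segment

# Stacking along axis -2 is equivalent to having encoded both segments
# back to back, which is what tf.concat(..., axis=-2) does above.
merged = np.concatenate([past_a, past_b], axis=-2)
print(merged.shape)  # (2, 12, 2, 12, 12, 64)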
Example #2
    def ensemble_decoding_beam_search_graph(self,
                                            context_list,
                                            beam_size,
                                            batch_size,
                                            max_decode_length,
                                            eos_id,
                                            model_num,
                                            decode_alpha=0.6):
        def step(hparams, tokens, past=None, scope=None):
            # One forward pass of a single ensemble member. `scope` is always
            # supplied in this graph: explicitly during context priming below,
            # and via scopes_for_ensemble inside the beam-search loop.
            if scope is not None:
                with tf.variable_scope(scope):
                    lm_output = model.model(hparams=hparams,
                                            X=tokens,
                                            past=past,
                                            reuse=tf.AUTO_REUSE)
                    logits = lm_output['logits']
                    presents = lm_output['present']
                    # Pin the static shape of the key/value cache; the batch
                    # dimension stays dynamic.
                    presents.set_shape(
                        model.past_shape(hparams=hparams, batch_size=None))
                    return {
                        'logits': logits,
                        'presents': presents,
                    }

        # Prime each ensemble member on its context (all tokens but the last)
        # inside its own variable scope, collecting one key/value cache per
        # model.
        context_output_list = []
        context_state_list = []
        all_scopes = []
        for i in range(0, model_num):
            with tf.variable_scope('model_' + str(i)) as sc:
                with tf.name_scope('sample_sequence'):
                    context_output_list.append(
                        step(self.hparams, context_list[i][:, :-1], scope=sc))
                    context_state_list.append(
                        context_output_list[-1]['presents'])
            all_scopes.append('model_' + str(i))
        with tf.name_scope('beam_search'):
            # The last context token seeds the beams; decoding continues from
            # the per-model caches gathered above.
            init_seq = tf.expand_dims(context_list[0][:, -1], axis=1)
            seqs, scores = beamsearch.create_inference_graph(
                init_seqs=init_seq,
                state=context_state_list,
                step_fn=step,
                hparams=self.hparams,
                decode_length=max_decode_length,
                batch_size=batch_size,
                beam_size=beam_size,
                decode_alpha=decode_alpha,
                eos_id=eos_id,
                scopes_for_ensemble=all_scopes,
                ensemble=True,
                concat_state_dim=None)
        return seqs, scores
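
How create_inference_graph merges the members is internal to beamsearch, but
a common ensembling rule, and a reasonable mental model here, is to average
the per-model log-probabilities at every decoding step. A self-contained
sketch of that rule (not the actual beamsearch internals):

import numpy as np

def ensemble_log_probs(per_model_logits):
    # per_model_logits: list of [batch, vocab] arrays, one per model.
    log_probs = []
    for logits in per_model_logits:
        z = logits - logits.max(axis=-1, keepdims=True)  # stable log-softmax
        log_probs.append(z - np.log(np.exp(z).sum(axis=-1, keepdims=True)))
    return np.mean(log_probs, axis=0)

# Two toy "models" scoring a four-token vocabulary for one sequence.
print(ensemble_log_probs([np.array([[1.0, 2.0, 0.5, 0.1]]),
                          np.array([[0.9, 1.5, 1.0, 0.2]])]))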
Example #3
    def build_beam_search_graph(self,
                                beam_size,
                                batch_size,
                                max_decode_length,
                                decode_alpha=0.6):
        # A single sequence of token ids; the batch size is fixed to 1 here.
        self.inputs = tf.placeholder(tf.int32, [1, None])

        def step(hparams, tokens, past=None):
            # One forward pass: next-token logits plus the updated key/value
            # cache ('present'), whose static shape is pinned below.
            lm_output = model.model(hparams=hparams,
                                    X=tokens,
                                    past=past,
                                    reuse=tf.AUTO_REUSE)
            logits = lm_output['logits']
            presents = lm_output['present']
            presents.set_shape(
                model.past_shape(hparams=hparams, batch_size=None))
            return {
                'logits': logits,
                'presents': presents,
            }

        with tf.name_scope('sample_sequence'):
            # Prime the model on everything but the last token to build the
            # key/value cache.
            context_output = step(self.hparams, self.inputs[:, :-1])
            context_state = context_output['presents']
        with tf.name_scope('beam_search'):
            # The last input token seeds the beams.
            init_seq = tf.expand_dims(self.inputs[:, -1], axis=1)
            seqs, scores = beamsearch.create_inference_graph(
                init_seqs=init_seq,
                state=context_state,
                step_fn=step,
                hparams=self.hparams,
                decode_length=max_decode_length,
                batch_size=batch_size,
                beam_size=beam_size,
                decode_alpha=decode_alpha,
                eos_id=self.eos_id,
                ensemble=False,
                concat_state_dim=-2)  # -2 is the sequence axis of the cache

        return seqs, scores
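
The priming split above, feed everything but the last token and then seed the
search with that last token, keeps the beam's first prediction consistent
with the cache. A minimal sketch of the split (the token ids are made up):

import numpy as np

context = np.array([[50256, 464, 3797, 318]])  # hypothetical token ids
prefix = context[:, :-1]  # fed through the model once to build the past cache
seed = context[:, -1:]    # becomes init_seqs; beams extend from this token
print(prefix, seed)       # [[50256 464 3797]] [[318]]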
Example #4
def build_beam_search_graph(self,
                            beam_size,
                            batch_size,
                            max_decode_length,
                            decode_alpha=0.6):
    self.def_placeholder_and_components()
    emb_out = []
    enc_h_out = []
    past_for_decoder = []
    # Embed and encode each input segment separately; unlike the first
    # example, the per-segment caches are kept in a list rather than
    # concatenated.
    for i in range(0, self.input_num):
        past_length = 0
        # Token embedding plus position embedding.
        h = tf.gather(self.wte, self.inputs[i]) + tf.gather(
            self.wpe, positions_for(self.inputs[i], past_length))
        emb_out.append(h)
        presents, h_enc = self.encoder.encode(h, self.input_lens[i])
        enc_h_out.append(h_enc)
        past_for_decoder.append(presents)
    past_length = 0 if enc_h_out[0] is None else tf.shape(enc_h_out[0])[-2]
    self.decoder.sef_var_for_beam_search(past_length,
                                         enc_h_out,
                                         beam_size=beam_size)
    with tf.name_scope('beam_search'):
        # Every beam starts from the start-of-sequence token.
        init_seq = tf.fill(dims=(batch_size, 1), value=self.sos_id)
        seqs, scores = beamsearch.create_inference_graph(
            init_seqs=init_seq,
            state=past_for_decoder,
            step_fn=self.decoder.decode_one_step,
            hparams=self.hparams,
            decode_length=max_decode_length,
            batch_size=batch_size,
            beam_size=beam_size,
            decode_alpha=decode_alpha,
            eos_id=self.eos_id,
            ensemble=False,
            concat_state_dim=None)
    return seqs, scores
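
positions_for is not defined in this listing; assuming it matches the helper
of the same name in the GPT-2 reference implementation, it produces a
[batch, nsteps] matrix of position indices offset by the cached prefix
length:

import tensorflow as tf

def expand_tile(value, size):
    # Add a new leading axis of the given size (GPT-2 reference helper).
    value = tf.convert_to_tensor(value, name='value')
    ndims = value.shape.ndims
    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1] * ndims)

def positions_for(tokens, past_length):
    # Position index of every token, shifted by the length of the cache.
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    return expand_tile(past_length + tf.range(nsteps), batch_size)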