Example #1
    def _tower_infer_graph(features):
        infer_fns = []
        for midx, (graph, params) in enumerate(zip(total_graphs,
                                                   total_params)):
            # Give each ensembled model its own variable scope so the
            # per-model graphs and parameters do not collide.
            params = copy.copy(params)
            params.scope_name = params.scope_name + "_ensembler_%d" % midx
            infer_fns.append(graph.infer_fn(params))

        total_encoding_fns, total_decoding_fns = list(zip(*infer_fns))

        def _encoding_fn(source):
            model_state = {}
            for _midx in range(len(total_encoding_fns)):
                current_model_state = total_encoding_fns[_midx](source)
                model_state['ensembler_%d' % _midx] = current_model_state
            return model_state

        def _decoding_fn(target, model_state, time):
            pred_logits = []

            for _midx in range(len(total_decoding_fns)):
                state_describ = "ensembler_%d" % _midx
                if default_params.search_mode == "cache":
                    current_output = total_decoding_fns[_midx](
                        target, model_state[state_describ], time)
                else:
                    current_output = total_decoding_fns[_midx](target,
                                                               model_state,
                                                               time)
                step_logits, step_state = current_output

                pred_logits.append(step_logits)

                if default_params.search_mode == "cache":
                    model_state[state_describ] = step_state

            # Average the per-model distributions in probability space, then
            # return to log space so the beam search keeps accumulating log-probs.
            model_logits = tf.add_n(
                [tf.nn.softmax(logits)
                 for logits in pred_logits]) / len(pred_logits)

            return tf.log(model_logits), model_state

        beam_output = beam_search(features, _encoding_fn, _decoding_fn,
                                  default_params)

        return beam_output
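
The decoding step above ensembles the models by averaging their softmax distributions and only then taking the log. A minimal, self-contained sketch of that averaging (the two-model setup and logits values are made up for illustration, not taken from the example):

import numpy as np

# Hypothetical per-model logits for one decoding step (batch=1, vocab=4).
logits_a = np.array([[2.0, 0.5, 0.1, -1.0]])
logits_b = np.array([[1.5, 1.0, 0.0, -0.5]])

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Average in probability space, then go back to log space, mirroring the
# tf.add_n(...) / len(pred_logits) followed by tf.log above.
ensemble_probs = (softmax(logits_a) + softmax(logits_b)) / 2
ensemble_log_probs = np.log(ensemble_probs)
print(ensemble_log_probs)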
Example #2
def infer_search(src_tokenizer, dst_tokenizer, transformer, config, method='beam_search'):
	# Build the inference graph for the requested decoding strategy.
	if method == 'beam_search':
		_, y_outputs, _, x_placeholder = beam_search(batch_size=1, beam_width=FLAGS.beam_width,
		                                                vocab_size=config.vocab_size, max_len=FLAGS.max_len,
		                                                hidden_size=config.hidden_size,
		                                                sos_id=dst_tokenizer.bos_id(),
		                                                eos_id=dst_tokenizer.eos_id(),
		                                                inst=transformer)
	elif method == 'greedy_search':
		_, y_outputs, x_placeholder = greedy_search(batch_size=1,
		                                                  max_len=FLAGS.max_len,
		                                                  sos_id=dst_tokenizer.bos_id(),
		                                                  eos_id=dst_tokenizer.eos_id(),
		                                                  inst=transformer)
	else:
		raise NotImplementedError('Not yet supported')

	sess = tf.Session()
	saver = tf.train.Saver()
	model_file = tf.train.latest_checkpoint(FLAGS.model_dir)
	saver.restore(sess=sess, save_path=model_file)

	fpw = open(file=FLAGS.infer_file + '.dst', mode='w', encoding='utf-8')
	with open(file=FLAGS.infer_file, mode='r', encoding='utf-8') as fp:
		for line in fp:
			line = line.strip()
			idxs = src_tokenizer.encode_as_ids(input=line)
			idxs = idxs[:FLAGS.max_len-1]
			idxs.append(src_tokenizer.eos_id())
			for i in range(len(idxs), FLAGS.max_len):
				idxs.append(0)
			y_idxs, = sess.run(
				fetches=[y_outputs],
				feed_dict={
					x_placeholder: [idxs]
				}
			)
			y_idxs_val = dst_tokenizer.decode_ids(ids=y_idxs[0].tolist())
			fpw.write(y_idxs_val + '\n')
	fpw.close()
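
Each source line above is encoded, truncated, terminated with EOS, and right-padded with zeros to a fixed length before being fed to the graph. A small sketch of that preprocessing as a stand-alone helper (the helper name and pad_id default are assumptions; the tokenizer is expected to expose encode_as_ids() and eos_id() as in the example):

def prepare_ids(line, tokenizer, max_len, pad_id=0):
    # Encode the raw text into subword ids.
    ids = tokenizer.encode_as_ids(input=line.strip())
    # Truncate, leaving room for the EOS id, then terminate the sequence.
    ids = ids[:max_len - 1]
    ids.append(tokenizer.eos_id())
    # Right-pad with pad_id so every fed sequence has the same length.
    ids += [pad_id] * (max_len - len(ids))
    return ids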
Example #3
    def _tower_infer_graph(features):
        encoding_fn, decoding_fn = graph.infer_fn(params)
        beam_output = beam_search(features, encoding_fn, decoding_fn, params)
        return beam_output
Example #4
    def run_batch(
        self,
        batch: Batch,
        recognition_beam_size: int = 1,
        translation_beam_size: int = 1,
        translation_beam_alpha: float = -1,
        translation_max_output_length: int = 100,
    ) -> (np.array, np.array, np.array):
        """
        Get outputs and attention scores for a given batch.

        :param batch: batch to generate hypotheses for
        :param recognition_beam_size: size of the beam for CTC beam search;
            if 1, greedy decoding is used
        :param translation_beam_size: size of the beam for translation beam search;
            if 1, greedy decoding is used
        :param translation_beam_alpha: alpha value for the length penalty in beam search
        :param translation_max_output_length: maximum length of translation hypotheses
        :return: decoded_gloss_sequences: decoded gloss sequences for the batch,
            stacked_txt_output: translation hypotheses for the batch,
            stacked_attention_scores: attention scores for the batch
        """

        encoder_output, encoder_hidden = self.encode(
            sgn=batch.sgn, sgn_mask=batch.sgn_mask, sgn_length=batch.sgn_lengths
        )

        if self.do_recognition:
            # Gloss Recognition Part
            # N x T x C
            gloss_scores = self.gloss_output_layer(encoder_output)
            # N x T x C
            gloss_probabilities = gloss_scores.log_softmax(2)
            # Turn it into T x N x C
            gloss_probabilities = gloss_probabilities.permute(1, 0, 2)
            gloss_probabilities = gloss_probabilities.cpu().detach().numpy()
            # TF's CTC decoder expects the blank label in the last class slot,
            # so move class 0 to the end of the class dimension.
            tf_gloss_probabilities = np.concatenate(
                (gloss_probabilities[:, :, 1:], gloss_probabilities[:, :, 0, None]),
                axis=-1,
            )

            assert recognition_beam_size > 0
            ctc_decode, _ = tf.nn.ctc_beam_search_decoder(
                inputs=tf_gloss_probabilities,
                sequence_length=batch.sgn_lengths.cpu().detach().numpy(),
                beam_width=recognition_beam_size,
                top_paths=1,
            )
            ctc_decode = ctc_decode[0]
            # Create a decoded gloss list for each sample
            tmp_gloss_sequences = [[] for i in range(gloss_scores.shape[0])]
            for (value_idx, dense_idx) in enumerate(ctc_decode.indices):
                # +1 restores the original gloss indices after the blank shift above.
                tmp_gloss_sequences[dense_idx[0]].append(
                    ctc_decode.values[value_idx].numpy() + 1
                )
            decoded_gloss_sequences = []
            for seq_idx in range(0, len(tmp_gloss_sequences)):
                decoded_gloss_sequences.append(
                    [x[0] for x in groupby(tmp_gloss_sequences[seq_idx])]
                )
        else:
            decoded_gloss_sequences = None

        if self.do_translation:
            # greedy decoding
            if translation_beam_size < 2:
                stacked_txt_output, stacked_attention_scores = greedy(
                    encoder_hidden=encoder_hidden,
                    encoder_output=encoder_output,
                    src_mask=batch.sgn_mask,
                    embed=self.txt_embed,
                    bos_index=self.txt_bos_index,
                    eos_index=self.txt_eos_index,
                    decoder=self.decoder,
                    max_output_length=translation_max_output_length,
                )
                # batch, time, max_sgn_length
            else:  # beam size
                stacked_txt_output, stacked_attention_scores = beam_search(
                    size=translation_beam_size,
                    encoder_hidden=encoder_hidden,
                    encoder_output=encoder_output,
                    src_mask=batch.sgn_mask,
                    embed=self.txt_embed,
                    max_output_length=translation_max_output_length,
                    alpha=translation_beam_alpha,
                    eos_index=self.txt_eos_index,
                    pad_index=self.txt_pad_index,
                    bos_index=self.txt_bos_index,
                    decoder=self.decoder,
                )
        else:
            stacked_txt_output = stacked_attention_scores = None

        return decoded_gloss_sequences, stacked_txt_output, stacked_attention_scores
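
The recognition branch above converts the SparseTensor returned by tf.nn.ctc_beam_search_decoder back into one gloss list per sample and collapses adjacent repeats with groupby. A minimal sketch of that regrouping with made-up indices/values arrays (the shapes follow the SparseTensor layout, but the numbers are invented for illustration):

from itertools import groupby

import numpy as np

# Hypothetical sparse CTC output for a batch of 2 samples.
indices = np.array([[0, 0], [0, 1], [1, 0]])  # (sample, position) pairs
values = np.array([3, 3, 7])                  # decoded label ids

sequences = [[] for _ in range(2)]
for value_idx, dense_idx in enumerate(indices):
    # +1 mirrors the blank-label shift undone in the example above.
    sequences[dense_idx[0]].append(values[value_idx] + 1)

# groupby keeps only the first of any run of adjacent duplicates.
decoded = [[x for x, _ in groupby(seq)] for seq in sequences]
print(decoded)  # [[4], [8]]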