def map_fn(n, d, p):
    # n: name prefix
    # d: data list
    # p: padding symbol
    data[concat_name(n, Constants.IDS_NAME)] = d
    n_samples = len(d)
    n_devices = len(input_fields)
    n_samples_per_gpu = n_samples // n_devices
    if n_samples % n_devices > 0:
        n_samples_per_gpu += 1

    def _feed_batchs(_start_idx, _inpf):
        if _start_idx * n_samples_per_gpu >= n_samples:
            return 0
        x, x_len = padding_batch_data(
            d[_start_idx * n_samples_per_gpu:
              (_start_idx + 1) * n_samples_per_gpu], p)
        data["feed_dict"][_inpf[concat_name(n, Constants.IDS_NAME)]] = x
        data["feed_dict"][_inpf[concat_name(n, Constants.LENGTH_NAME)]] = x_len
        return len(x_len)

    parallels = repeat_n_times(
        n_devices, _feed_batchs,
        list(range(n_devices)), input_fields)
    data["feed_dict"]["parallels"] = parallels
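# `map_fn` above shards one batch across devices and relies on
# `padding_batch_data` to turn a list of variable-length id sequences into a
# rectangular feed. That helper is defined elsewhere; the sketch below only
# illustrates the behavior inferred from the call site -- a signature of
# (sequences, padding_symbol) returning (padded_matrix, lengths) -- and the
# name `padding_batch_data_sketch` is hypothetical.
import numpy


def padding_batch_data_sketch(seqs, padding_symbol):
    """Pads `seqs` (a list of lists of token ids) into a rectangular
    int32 array and also returns the true sequence lengths."""
    lengths = [len(s) for s in seqs]
    padded = numpy.full([len(seqs), max(lengths)], padding_symbol,
                        dtype=numpy.int32)
    for row, seq in enumerate(seqs):
        padded[row, :len(seq)] = seq
    return padded, numpy.asarray(lengths, dtype=numpy.int32)

# For example, padding_batch_data_sketch([[3, 4], [5]], 0) yields
# [[3, 4], [5, 0]] with lengths [2, 1].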
def evaluate_with_attention(sess,
                            loss_op,
                            eval_data,
                            vocab_source,
                            vocab_target,
                            attention_op=None,
                            output_filename_prefix=None):
    """ Evaluates data by loss.

    Args:
        sess: `tf.Session`.
        loss_op: Tensorflow operation, computing the loss.
        eval_data: An iterable whose elements are packed feeding
          dictionaries for `sess`.
        vocab_source: A `Vocab` instance for source side feature map.
        vocab_target: A `Vocab` instance for target side feature map.
        attention_op: Tensorflow operation for output attention.
        output_filename_prefix: A string.

    Returns: Total loss divided by the accumulated loss weights
      (a weighted average of the per-device losses).
    """
    losses = 0.
    weights = 0.
    num_of_samples = 0
    attentions = {}
    for data in eval_data:
        _n_samples = len(data["feature_ids"])
        parallels = data["feed_dict"].pop("parallels")
        avail = sum(numpy.array(parallels) > 0)
        if attention_op is None:
            loss = _evaluate(sess, data["feed_dict"], loss_op[:avail])
        else:
            loss, atts = _evaluate(sess, data["feed_dict"],
                                   [loss_op[:avail], attention_op[:avail]])
            ss_strs = [vocab_source.convert_to_wordlist(ss, bpe_decoding=False)
                       for ss in data["feature_ids"]]
            tt_strs = [vocab_target.convert_to_wordlist(
                tt, bpe_decoding=False, reverse_seq=False)
                for tt in data["label_ids"]]
            _attentions = sum(repeat_n_times(
                avail, select_attention_sample_by_sample, atts), [])
            attentions.update(pack_batch_attention_dict(
                num_of_samples, ss_strs, tt_strs, _attentions))
        data["feed_dict"]["parallels"] = parallels
        losses += sum([_l[0] for _l in loss])
        weights += sum([_l[1] for _l in loss])
        num_of_samples += _n_samples
    loss = losses / weights
    if attention_op is not None:
        dump_attentions(output_filename_prefix, attentions)
    return loss
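# `evaluate_with_attention` accumulates per-device `(loss_sum, weight)` pairs
# (the layout inferred from `_l[0]` / `_l[1]` above) and divides the totals
# only at the end, so devices and batches with more tokens contribute
# proportionally more. A tiny worked example of that reduction:
per_device_losses = [(12.0, 4.0), (9.0, 3.0)]  # (weighted loss sum, weight)
total_loss = sum(l for l, _ in per_device_losses)    # 21.0
total_weight = sum(w for _, w in per_device_losses)  # 7.0
mean_loss = total_loss / total_weight                # 3.0, not 21.0 / 2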
def _add(prefix):
    nonpadding_tokens_num, total_tokens_num = repeat_n_times(
        len(input_fields), compute_non_padding_num, input_fields, prefix)
    nonpadding_tokens_num = tf.reduce_sum(nonpadding_tokens_num)
    total_tokens_num = tf.reduce_sum(total_tokens_num)
    tf.add_to_collection(Constants.DISPLAY_KEY_COLLECTION_NAME,
                         "input_stats/{}_nonpadding_tokens_num".format(prefix))
    tf.add_to_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME,
                         nonpadding_tokens_num)
    tf.add_to_collection(Constants.DISPLAY_KEY_COLLECTION_NAME,
                         "input_stats/{}_nonpadding_ratio".format(prefix))
    tf.add_to_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME,
                         tf.to_float(nonpadding_tokens_num)
                         / tf.to_float(total_tokens_num))
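# `compute_non_padding_num` is defined elsewhere. Given how its two outputs
# are summed and turned into a ratio above, it plausibly returns
# (non-padding token count, total token count) for one device's input field.
# The sketch below is hypothetical: it assumes each field exposes ids and
# lengths under the same naming scheme used in `map_fn`.
import tensorflow as tf


def compute_non_padding_num_sketch(input_field, prefix):
    ids = input_field[concat_name(prefix, Constants.IDS_NAME)]
    lengths = input_field[concat_name(prefix, Constants.LENGTH_NAME)]
    # Non-padding positions are those before each sequence's true length.
    mask = tf.sequence_mask(lengths, maxlen=tf.shape(ids)[1], dtype=tf.int32)
    return tf.reduce_sum(mask), tf.size(ids)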
def build(self, base_models, vocab_target, input_fields):
    """ Builds the ensemble model.

    Args:
        base_models: A list of `BaseSeq2Seq` instances.
        vocab_target: An instance of `Vocab`.
        input_fields: A dict of placeholders.

    Returns: A dictionary of inference status.
    """
    encoder_outputs = []
    # prepare for decoding of each model
    for index, model in enumerate(base_models):
        with tf.variable_scope(
                Constants.ENSEMBLE_VARNAME_PREFIX + str(index)):
            with tf.variable_scope(model.name):
                encoder_output = model._encode(input_fields=input_fields)
                vs_name = tf.get_variable_scope().name
                model._decoder.name = os.path.join(
                    vs_name, model._decoder.name)
                model._target_modality.name = os.path.join(
                    vs_name, model._target_modality.name)
        encoder_outputs.append(encoder_output)

    helper = BeamFeedback(
        vocab=vocab_target,
        batch_size=tf.shape(input_fields[Constants.FEATURE_IDS_NAME])[0],
        maximum_labels_length=self._maximum_labels_length,
        beam_size=self._beam_size,
        alpha=self._length_penalty,
        ensemble_weight=self.get_ensemble_weights(len(base_models)))
    decoders, bridges, target_to_emb_fns, outputs_to_logits_fns = \
        repeat_n_times(
            len(base_models),
            lambda m: (m._decoder, m._encoder_decoder_bridge,
                       m._target_to_embedding_fn, m._outputs_to_logits_fn),
            base_models)
    decoding_result = dynamic_ensemble_decode(
        decoders=decoders,
        encoder_outputs=encoder_outputs,
        bridges=bridges,
        helper=helper,
        target_to_embedding_fns=target_to_emb_fns,
        outputs_to_logits_fns=outputs_to_logits_fns,
        beam_size=self._beam_size)
    predict_out = process_beam_predictions(
        decoding_result=decoding_result,
        beam_size=self._beam_size,
        alpha=self._length_penalty)
    predict_out["source"] = input_fields[Constants.FEATURE_IDS_NAME]
    return predict_out
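# The renaming above makes later `tf.variable_scope(decoder.name)` calls
# recreate each model's decode-side variables under its own
# "<ENSEMBLE_VARNAME_PREFIX><index>/<model_name>/..." path, so the N sets of
# weights do not collide. A minimal standalone illustration of that TF1
# scoping behavior (the scope names here are hypothetical):
import os
import tensorflow as tf

with tf.variable_scope("ensemble0"):     # Constants.ENSEMBLE_VARNAME_PREFIX + "0"
    with tf.variable_scope("model_a"):   # model.name
        full_scope = tf.get_variable_scope().name  # "ensemble0/model_a"
decoder_name = os.path.join(full_scope, "decoder")
with tf.variable_scope(decoder_name):
    w = tf.get_variable("w", shape=[1])
# w.name == "ensemble0/model_a/decoder/w:0"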
def build(self, input_fields):
    """ Builds the ensemble model.

    Args:
        input_fields: A dict of placeholders.

    Returns: A dictionary of inference status.
    """
    encoder_outputs = []
    # prepare for decoding of each model
    for index, model in enumerate(self._base_models):
        encoder_output = model._encode(input_fields=input_fields)
        encoder_outputs.append(encoder_output)

    helper = BeamFeedback(
        vocab=self._vocab_target,
        batch_size=tf.shape(input_fields[Constants.FEATURE_IDS_NAME])[0],
        maximum_labels_length=self._maximum_labels_length,
        beam_size=self._beam_size,
        alpha=self._length_penalty,
        ensemble_weight=self.get_ensemble_weights(len(self._base_models)))
    decoders, bridges, target_to_emb_fns, outputs_to_logits_fns = \
        repeat_n_times(
            len(self._base_models),
            lambda m: (m._decoder, m._encoder_decoder_bridge,
                       m._target_to_embedding_fn, m._outputs_to_logits_fn),
            self._base_models)
    decoding_result = dynamic_ensemble_decode(
        decoders=decoders,
        encoder_outputs=encoder_outputs,
        bridges=bridges,
        helper=helper,
        target_to_embedding_fns=target_to_emb_fns,
        outputs_to_logits_fns=outputs_to_logits_fns,
        beam_size=self._beam_size)
    predict_out = process_beam_predictions(
        decoding_result=decoding_result,
        beam_size=self._beam_size,
        alpha=self._length_penalty)
    predict_out["source"] = input_fields[Constants.FEATURE_IDS_NAME]
    return predict_out
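# Both `build` variants pass `self.get_ensemble_weights(n)` to `BeamFeedback`
# as `ensemble_weight`. That method is not shown here; a common choice, and a
# plausible default, is uniform averaging of the models' distributions. The
# sketch below is an assumption, not the confirmed implementation:
def get_ensemble_weights_sketch(num_models):
    """Uniform ensemble weights that sum to 1."""
    return [1.0 / float(num_models)] * num_models

# get_ensemble_weights_sketch(4) -> [0.25, 0.25, 0.25, 0.25]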
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            helper,
                            target_to_embedding_fns,
                            outputs_to_logits_fns,
                            parallel_iterations=32,
                            swap_memory=False,
                            **kwargs):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        target_to_embedding_fns: A list of callables, converts target ids
          to embeddings.
        outputs_to_logits_fns: A list of callables, converts decoder
          outputs to logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        kwargs: Other keyword arguments. Must contain "beam_size".

    Returns: A dict of decoding results: the `beam_ids` and `log_probs`
      stacked from `BeamSearchStateSpec` (indicating the status of beam
      search), the `decoding_length`, and the `hypothesis` token ids.
    """
    num_models = len(decoders)
    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop.
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        return tf.TensorArray(
            dtype=d, clear_after_read=False, size=0, dynamic_size=True)

    decoder_output_removers = repeat_n_times(
        num_models,
        lambda dec: DecoderOutputRemover(
            dec.mode, dec.output_dtype._fields, dec.output_ignore_fields),
        decoders)

    # Initialize first inputs (start-of-sentence symbols)
    # with shape [batch_size * beam_size, ].
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_inputs = repeat_n_times(
        num_models, target_to_embedding_fns,
        initial_input_symbols, initial_time)

    assert "beam_size" in kwargs
    beam_size = kwargs["beam_size"]

    def _create_cache(_decoder, _encoder_output, _bridge):
        with tf.variable_scope(_decoder.name):
            _init_cache = _decoder.prepare(_encoder_output, _bridge, helper)
            _init_cache = stack_beam_size(_init_cache, beam_size)
        return _init_cache

    initial_caches = repeat_n_times(
        num_models, _create_cache, decoders, encoder_outputs, bridges)

    initial_outputs_tas = [
        nest.map_structure(
            _create_ta, _decoder_output_remover.apply(_decoder.output_dtype))
        for _decoder_output_remover, _decoder
        in zip(decoder_output_removers, decoders)]

    def body_infer(time, inputs, caches, outputs_tas, finished,
                   log_probs, lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
            time: Scalar int32 Tensor.
            inputs: A list of input Tensors, one per model.
            caches: A list of dicts of decoder states, one per model.
            outputs_tas: A list of TensorArrays.
            finished: A bool Tensor (keeping track of what's finished).
            log_probs: The log probability Tensor.
            lengths: The decoding length Tensor.
            bs_stat_ta: A structure of TensorArrays recording beam search
              status.
            predicted_ids: A Tensor of the token ids predicted so far.

        Returns: `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
          next_predicted_ids)`.
        """

        # Step each decoder once.
        def _decoding(_decoder, _input, _cache, _decoder_output_remover,
                      _outputs_ta, _outputs_to_logits_fn):
            with tf.variable_scope(_decoder.name):
                _output, _next_cache = _decoder.step(_input, _cache)
                _decoder_top_features = _decoder.merge_top_features(_output)
            _ta = nest.map_structure(
                lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms),
                _outputs_ta, _decoder_output_remover.apply(_output))
            _logit = _outputs_to_logits_fn(_decoder_top_features)
            return _output, _next_cache, _ta, _logit

        outputs, next_caches, next_outputs_tas, logits = repeat_n_times(
            num_models, _decoding,
            decoders, inputs, caches, decoder_output_removers,
            outputs_tas, outputs_to_logits_fns)

        # Sample next symbols.
        sample_ids, beam_ids, next_log_probs, next_lengths = \
            helper.sample_symbols(logits, log_probs, finished, lengths,
                                  time=time)
        for c in next_caches:
            c["decoding_states"] = gather_states(
                c["decoding_states"], beam_ids)

        infer_status = BeamSearchStateSpec(
            log_probs=next_log_probs, beam_ids=beam_ids)
        bs_stat_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), bs_stat_ta, infer_status)
        predicted_ids = gather_states(
            tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        next_predicted_ids = tf.concat(
            [predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])

        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = repeat_n_times(
            num_models, target_to_embedding_fns,
            next_input_symbols, time + 1)
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, bs_stat_ta, \
               next_predicted_ids

    initial_log_probs = tf.zeros_like(initial_input_symbols,
                                      dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_bs_stat_ta = nest.map_structure(
        _create_ta, BeamSearchStateSpec.dtypes())
    initial_input_symbols.set_shape([None])

    loop_vars = [initial_time, initial_inputs, initial_caches,
                 initial_outputs_tas, initial_finished,
                 # infer vars
                 initial_log_probs, initial_lengths,
                 initial_bs_stat_ta, initial_input_symbols]
    res = tf.while_loop(
        lambda *args: tf.logical_not(tf.reduce_all(args[4])),
        body_infer,
        loop_vars=loop_vars,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    timesteps = res[0] + 1
    log_probs, length, bs_stat, predicted_ids = res[-4:]
    final_bs_stat = nest.map_structure(lambda ta: ta.stack(), bs_stat)
    return {"beam_ids": final_bs_stat.beam_ids,
            "log_probs": final_bs_stat.log_probs,
            "decoding_length": length,
            "hypothesis": tf.reshape(predicted_ids, [-1, timesteps])[:, 1:]}
def _infer(sess,
           feed_dict,
           prediction_op,
           batch_size,
           top_k=1,
           output_attention=False):
    """ Infers a batch of samples with beam search.

    Args:
        sess: `tf.Session`.
        feed_dict: A dictionary of feeding data.
        prediction_op: Tensorflow operation for inference.
        batch_size: The batch size.
        top_k: An integer, the number of predicted sequences to return
          per sample.
        output_attention: Whether to output attention.

    Returns: A tuple `(predicted_sequences, attention_scores)`.
      `predicted_sequences` is a list of hypotheses of shape approximately
      [`top_k` * `batch_size`, sequence_length]. `attention_scores` is an
      empty list if there is no attention-related information in
      `prediction_op`.
    """
    parallels = feed_dict.pop("parallels")
    avail = sum(numpy.array(parallels) > 0)
    extract_keys = ["sorted_hypothesis"]
    if output_attention:
        assert top_k == 1
        extract_keys.extend(["sorted_argidx", "attentions", "beam_ids"])
    brief_pred_op = dict(zip(
        extract_keys,
        repeat_n_times(
            avail,
            lambda dd: tuple([dd[k] for k in extract_keys]),
            prediction_op[:avail])))
    predict_out = sess.run(brief_pred_op, feed_dict=feed_dict)
    feed_dict["parallels"] = parallels
    total_samples = sum(repeat_n_times(
        avail, lambda p: p.shape[0], predict_out["sorted_hypothesis"]))
    beam_size = total_samples // batch_size

    def _post_process_hypo(pred, **kwargs):
        _num_samples = pred.shape[0]
        _batch_size = _num_samples // beam_size
        # Position of the first (best) beam of each sample, tiled:
        #   [[0, 0, ...], [beam_size, beam_size, ...], ...]
        batch_beam_pos = numpy.tile(
            numpy.arange(_batch_size) * beam_size,
            [beam_size, 1]).transpose()
        # Offsets 0..top_k-1 select the top-k hypotheses of each sample.
        batch_beam_topk_add = numpy.tile(
            numpy.arange(top_k), [batch_beam_pos.shape[0], 1])
        batch_beam_pos = numpy.reshape(
            batch_beam_pos[:, :top_k] + batch_beam_topk_add, -1)
        if output_attention:
            atts = postprocess_attention(
                beam_ids=kwargs["beam_ids"],
                attention_dict=kwargs["attentions"],
                gather_idx=kwargs["sorted_argidx"][batch_beam_pos])
            return pred[batch_beam_pos, :].tolist(), atts
        return pred[batch_beam_pos, :].tolist(), []

    hypothesises, attentions = repeat_n_times(
        avail, _post_process_hypo,
        predict_out["sorted_hypothesis"],
        beam_ids=predict_out.get("beam_ids", None),
        attentions=predict_out.get("attentions", None),
        sorted_argidx=predict_out.get("sorted_argidx", None))
    hypothesis = sum(hypothesises, [])
    attention = sum(attentions, [])
    return hypothesis, attention
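# The index arithmetic in `_post_process_hypo` picks the top-k rows of each
# sample out of a [batch_size * beam_size, ...] array of beam-sorted
# hypotheses. A worked example with batch_size=2, beam_size=3, top_k=2:
import numpy

beam_size, top_k, batch_size = 3, 2, 2
base = numpy.tile(numpy.arange(batch_size) * beam_size,
                  [beam_size, 1]).transpose()       # [[0, 0, 0], [3, 3, 3]]
offsets = numpy.tile(numpy.arange(top_k), [batch_size, 1])  # [[0, 1], [0, 1]]
pos = numpy.reshape(base[:, :top_k] + offsets, -1)          # [0, 1, 3, 4]
# Rows 0-2 hold sample 0's beams (best first), rows 3-5 sample 1's, so
# `pos` selects the two best hypotheses of each sample. Without the offsets,
# the slice would repeat each sample's first beam top_k times.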