def body_infer(time, inputs, caches, outputs_tas, finished,
               log_probs, lengths, infer_status_ta):
    """Internal while_loop body.

    Args:
        time: Scalar int32 Tensor.
        inputs: A list of input Tensors.
        caches: A list of decoder state dicts.
        outputs_tas: A list of TensorArrays.
        finished: A bool Tensor (keeping track of what's finished).
        log_probs: The log probability Tensor.
        lengths: The decoding length Tensor.
        infer_status_ta: A structure of TensorArrays.

    Returns:
        `(time + 1, next_inputs, next_caches, next_outputs_tas,
        next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
    """
    # step decoder
    outputs = []
    next_caches = []
    for dec, inp, cache in zip(decoders, inputs, caches):
        with tf.variable_scope(dec.name):
            out, next_cache = dec.step(inp, cache)
        outputs.append(out)
        next_caches.append(next_cache)
    next_outputs_tas = []
    for out_ta, out, rem in zip(outputs_tas, outputs, decoder_output_removers):
        ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                out_ta, rem.apply(out))
        next_outputs_tas.append(ta)
    logits = []
    for dec, modality, out in zip(decoders, target_modalities, outputs):
        logits.append(_compute_logits(dec, modality, out))
    # sample next symbols
    sample_ids, beam_ids, next_log_probs, next_lengths = \
        helper.sample_symbols(logits, log_probs, finished, lengths, time=time)
    for c in next_caches:
        c["decoding_states"] = gather_states(c["decoding_states"], beam_ids)
    infer_status = BeamSearchStateSpec(
        log_probs=next_log_probs,
        predicted_ids=sample_ids,
        beam_ids=beam_ids,
        lengths=next_lengths)
    infer_status_ta = nest.map_structure(
        lambda ta, out: ta.write(time, out),
        infer_status_ta, infer_status)
    next_finished, next_input_symbols = helper.next_symbols(
        time=time, sample_ids=sample_ids)
    next_inputs = nest.map_structure(
        lambda modality: _embed_words(modality, next_input_symbols, time + 1),
        target_modalities)
    next_finished = tf.logical_or(next_finished, finished)
    return time + 1, next_inputs, next_caches, next_outputs_tas, \
           next_finished, next_log_probs, next_lengths, infer_status_ta
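# Illustrative sketch (not part of the original module): a minimal,
# self-contained stand-in for the `gather_states(...)` call above, showing
# how `beam_ids` reorders cached decoder states along the merged
# batch*beam axis after each sampling step.  The helper name
# `_gather_states_sketch` and the toy values are assumptions.
import tensorflow as tf
from tensorflow.python.util import nest


def _gather_states_sketch(states, beam_ids):
    """Reorders every Tensor in the (possibly nested) `states` by `beam_ids`."""
    return nest.map_structure(lambda t: tf.gather(t, beam_ids), states)


# e.g. with batch*beam = 4 state rows, where beams 0 and 2 each survive twice:
# _gather_states_sketch({"decoding_states": tf.eye(4)},
#                       tf.constant([0, 0, 2, 2]))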
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            target_modalities,
                            helper,
                            parallel_iterations=32,
                            swap_memory=False):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        target_modalities: A list of `Modality` instances.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.

    Returns:
        The results of inference, an instance of `collections.namedtuple`
        whose element types are defined by `BeamSearchStateSpec`,
        indicating the status of beam search.
    """
    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop.
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        return tf.TensorArray(
            dtype=d, clear_after_read=False, size=0, dynamic_size=True)

    decoder_output_removers = nest.map_structure(
        lambda dec: DecoderOutputRemover(
            dec.mode, dec.output_dtype._fields, dec.output_ignore_fields),
        decoders)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_input_symbols_embed = nest.map_structure(
        lambda modality: _embed_words(modality, initial_input_symbols, initial_time),
        target_modalities)

    inputs_preprocessing_fns = []
    inputs_postprocessing_fns = []
    initial_inputs = []
    initial_decoder_states = []
    decoding_params = []
    for dec, enc_out, bri, inp in zip(decoders, encoder_outputs, bridges,
                                      initial_input_symbols_embed):
        with tf.variable_scope(dec.name):
            inputs_preprocessing_fn, inputs_postprocessing_fn = \
                dec.inputs_prepost_processing_fn()
            inputs = inputs_postprocessing_fn(None, inp)
            dec_states, dec_params = dec.prepare(enc_out, bri, helper)  # prepare decoder
            dec_states = stack_beam_size(dec_states, helper.beam_size)
            dec_params = stack_beam_size(dec_params, helper.beam_size)
        # add to list
        inputs_preprocessing_fns.append(inputs_preprocessing_fn)
        inputs_postprocessing_fns.append(inputs_postprocessing_fn)
        initial_inputs.append(inputs)
        initial_decoder_states.append(dec_states)
        decoding_params.append(dec_params)

    initial_outputs_tas = nest.map_structure(
        lambda dec_out_rem, dec: nest.map_structure(
            _create_ta, dec_out_rem.apply(dec.output_dtype)),
        decoder_output_removers, decoders)

    def body_infer(time, inputs, decoder_states, outputs_tas, finished,
                   log_probs, lengths, infer_status_ta):
        """Internal while_loop body.

        Args:
            time: Scalar int32 Tensor.
            inputs: A list of input Tensors.
            decoder_states: A list of decoder states.
            outputs_tas: A list of TensorArrays.
            finished: A bool Tensor (keeping track of what's finished).
            log_probs: The log probability Tensor.
            lengths: The decoding length Tensor.
            infer_status_ta: A structure of TensorArrays.

        Returns:
            `(time + 1, next_inputs, next_decoder_states, next_outputs_tas,
            next_finished, next_log_probs, next_lengths,
            next_infer_status_ta)`.
        """
        # step decoder
        outputs = []
        cur_inputs = []
        next_decoder_states = []
        for dec, inp, pre_fn, stat, dec_params in \
                zip(decoders, inputs, inputs_preprocessing_fns,
                    decoder_states, decoding_params):
            with tf.variable_scope(dec.name):
                inp = pre_fn(time, inp)
                out, next_stat = dec.step(inp, stat, dec_params)
            cur_inputs.append(inp)
            outputs.append(out)
            next_decoder_states.append(next_stat)
        next_outputs_tas = []
        for out_ta, out, rem in zip(outputs_tas, outputs, decoder_output_removers):
            ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    out_ta, rem.apply(out))
            next_outputs_tas.append(ta)
        logits = []
        for dec, modality, out in zip(decoders, target_modalities, outputs):
            logits.append(_compute_logits(dec, modality, out))
        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths = \
            helper.sample_symbols(logits, log_probs, finished, lengths, time=time)
        gathered_states = []
        for next_stat in next_decoder_states:
            gathered_states.append(gather_states(next_stat, beam_ids))
        cur_inputs = nest.map_structure(
            lambda inp: gather_states(inp, beam_ids), cur_inputs)
        infer_status = BeamSearchStateSpec(
            log_probs=next_log_probs,
            predicted_ids=sample_ids,
            beam_ids=beam_ids,
            lengths=next_lengths)
        infer_status_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out),
            infer_status_ta, infer_status)
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs_embed = nest.map_structure(
            lambda modality: _embed_words(modality, next_input_symbols, time + 1),
            target_modalities)
        next_finished = tf.logical_or(next_finished, finished)
        next_inputs = []
        for dec, cur_inp, next_inp, post_fn in zip(
                decoders, cur_inputs, next_inputs_embed, inputs_postprocessing_fns):
            with tf.variable_scope(dec.name):
                next_inputs.append(post_fn(cur_inp, next_inp))
        return time + 1, next_inputs, gathered_states, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, infer_status_ta

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_infer_status_ta = nest.map_structure(
        _create_ta, BeamSearchStateSpec.dtypes())
    loop_vars = [initial_time, initial_inputs, initial_decoder_states,
                 initial_outputs_tas, initial_finished,
                 # infer vars
                 initial_log_probs, initial_lengths, initial_infer_status_ta]

    res = tf.while_loop(
        lambda *args: tf.logical_not(tf.reduce_all(args[4])),
        body_infer,
        loop_vars=loop_vars,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    final_infer_status = nest.map_structure(lambda ta: ta.stack(), res[-1])
    return final_infer_status
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            target_modalities,
                            helper,
                            parallel_iterations=32,
                            swap_memory=False,
                            **kwargs):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        target_modalities: A list of `Modality` instances.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        kwargs: Extra keyword arguments; must contain "beam_size".

    Returns:
        A dict containing the beam ids, log probabilities, decoding
        lengths and hypotheses of the beam search results.
    """
    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop.
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        return tf.TensorArray(
            dtype=d, clear_after_read=False, size=0, dynamic_size=True)

    decoder_output_removers = nest.map_structure(
        lambda dec: DecoderOutputRemover(
            dec.mode, dec.output_dtype._fields, dec.output_ignore_fields),
        decoders)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_inputs = nest.map_structure(
        lambda modality: _embed_words(modality, initial_input_symbols, initial_time),
        target_modalities)

    assert "beam_size" in kwargs
    beam_size = kwargs["beam_size"]

    initial_caches = []
    for dec, enc_out, bri in zip(decoders, encoder_outputs, bridges):
        with tf.variable_scope(dec.name):
            init_cache = dec.prepare(enc_out, bri, helper)  # prepare decoder
            init_cache = stack_beam_size(init_cache, beam_size)
        initial_caches.append(init_cache)

    initial_outputs_tas = nest.map_structure(
        lambda dec_out_rem, dec: nest.map_structure(
            _create_ta, dec_out_rem.apply(dec.output_dtype)),
        decoder_output_removers, decoders)

    def body_infer(time, inputs, caches, outputs_tas, finished,
                   log_probs, lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
            time: Scalar int32 Tensor.
            inputs: A list of input Tensors.
            caches: A list of decoder state dicts.
            outputs_tas: A list of TensorArrays.
            finished: A bool Tensor (keeping track of what's finished).
            log_probs: The log probability Tensor.
            lengths: The decoding length Tensor.
            bs_stat_ta: A structure of TensorArrays recording the beam
              search status.
            predicted_ids: A flattened Tensor of the ids predicted so far.

        Returns:
            `(time + 1, next_inputs, next_caches, next_outputs_tas,
            next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
            next_predicted_ids)`.
        """
        # step decoder
        outputs = []
        next_caches = []
        for dec, inp, cache in zip(decoders, inputs, caches):
            with tf.variable_scope(dec.name):
                out, next_cache = dec.step(inp, cache)
            outputs.append(out)
            next_caches.append(next_cache)
        next_outputs_tas = []
        for out_ta, out, rem in zip(outputs_tas, outputs, decoder_output_removers):
            ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    out_ta, rem.apply(out))
            next_outputs_tas.append(ta)
        logits = []
        for dec, modality, out in zip(decoders, target_modalities, outputs):
            logits.append(_compute_logits(dec, modality, out))
        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths = \
            helper.sample_symbols(logits, log_probs, finished, lengths, time=time)
        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"], beam_ids)
        infer_status = BeamSearchStateSpec(
            log_probs=next_log_probs,
            beam_ids=beam_ids)
        bs_stat_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out),
            bs_stat_ta, infer_status)
        predicted_ids = gather_states(
            tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        next_predicted_ids = tf.concat(
            [predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = nest.map_structure(
            lambda modality: _embed_words(modality, next_input_symbols, time + 1),
            target_modalities)
        next_finished = tf.logical_or(next_finished, finished)
        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, bs_stat_ta, \
               next_predicted_ids

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_bs_stat_ta = nest.map_structure(
        _create_ta, BeamSearchStateSpec.dtypes())
    initial_input_symbols.set_shape([None])
    loop_vars = [initial_time, initial_inputs, initial_caches,
                 initial_outputs_tas, initial_finished,
                 # infer vars
                 initial_log_probs, initial_lengths, initial_bs_stat_ta,
                 initial_input_symbols]

    res = tf.while_loop(
        lambda *args: tf.logical_not(tf.reduce_all(args[4])),
        body_infer,
        loop_vars=loop_vars,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    timesteps = res[0] + 1
    log_probs, length, bs_stat, predicted_ids = res[-4:]
    final_bs_stat = nest.map_structure(lambda ta: ta.stack(), bs_stat)
    return {
        "beam_ids": final_bs_stat.beam_ids,
        "log_probs": final_bs_stat.log_probs,
        "decoding_length": length,
        "hypothesis": tf.reshape(predicted_ids, [-1, timesteps])[:, 1:]
    }