def _cell_fn(theta, state0, acc_state, acc_gate, i):
  """Runs the cell for step `i` and records its results in the accumulators.

  Relies on `inputs`, `cell_fn`, `skipped_state` and `reverse` from the
  enclosing scope.

  Args:
    theta: cell weights; forwarded to `cell_fn` and returned untouched.
    state0: mapping holding the state entering this step.
    acc_state: mapping of per-key state accumulators. For every key of
      `state0` that is not in `skipped_state`, slot `i` is overwritten with
      the new state. The mapping itself is modified in place.
    acc_gate: accumulator whose slot `i` receives the gate value returned by
      `cell_fn`.
    i: index of the step being processed.

  Returns:
    A tuple `(theta, state1, acc_state, acc_gate, next_i)` where `next_i` is
    `i - 1` when `reverse` is truthy and `i + 1` otherwise.
  """
  # Pick out the i-th entry of every input.
  step_inputs = {}
  for name in inputs:
    step_inputs[name] = tf.gather(inputs[name], i)

  state1, gate = cell_fn(theta, state0, step_inputs)

  for name in state0:
    if name in skipped_state:
      continue
    written = inplace_ops.alias_inplace_update(acc_state[name], i,
                                               state1[name])
    # The accumulators are bookkeeping only; block gradients through them.
    acc_state[name] = tf.stop_gradient(written)

  gate_written = inplace_ops.alias_inplace_update(acc_gate, i, gate)
  acc_gate = tf.stop_gradient(gate_written)

  if reverse:
    next_i = i - 1
  else:
    next_i = i + 1
  return theta, state1, acc_state, acc_gate, next_i
def _cell_grad_fn_with_state0(state0, theta, dy, dstate1, dtheta, dinput, i):
  """Computes the backward pass of the cell for step `i`.

  Relies on `inputs`, `acc_gate`, `cell_grad`, `skipped_state`,
  `skipped_theta` and `reverse` from the enclosing scope.

  Args:
    state0: mapping holding the state that entered step `i`; keys listed in
      `skipped_state` are dropped before calling `cell_grad`.
    theta: mapping of cell weights.
    dy: mapping of per-step output gradients; entry `i` of each value is
      added to the matching entry of `dstate1`.
    dstate1: mapping of gradients w.r.t. the state leaving step `i`. Modified
      in place for every key of `dy`.
    dtheta: mapping of accumulated weight gradients.
    dinput: mapping of input-gradient accumulators; slot `i` is overwritten
      for every key returned by `cell_grad`.
    i: index of the step being processed.

  Returns:
    A tuple `(theta, dy, dstate, dtheta, dinput, next_i)`. `theta` is the
    gradient-stopped copy of the weights, `dtheta` omits keys found in
    `skipped_theta`, and `next_i` is `i + 1` when `reverse` is truthy and
    `i - 1` otherwise.
  """
  # Treat the forward values as constants while building the backward step.
  frozen_state = {}
  for name in state0:
    if name not in skipped_state:
      frozen_state[name] = tf.stop_gradient(state0[name])

  frozen_theta = {}
  for name in theta:
    frozen_theta[name] = tf.stop_gradient(theta[name])

  # Only the "padding" input (when present) is handed to the gradient cell.
  inputs_slice = None
  if "padding" in inputs:
    inputs_slice = {"padding": tf.gather(inputs["padding"], i)}

  gate = tf.gather(acc_gate, i)

  for name in dy:
    step_dy = tf.gather(dy[name], i)
    dstate1[name] = dstate1[name] + step_dy

  dt, dstate, di = cell_grad(frozen_theta, frozen_state, inputs_slice, gate,
                             dstate1)

  summed_dtheta = {}
  for name in dtheta:
    if name not in skipped_theta:
      summed_dtheta[name] = dtheta[name] + dt[name]

  written_dinput = {}
  for name in di:
    written_dinput[name] = inplace_ops.alias_inplace_update(
        dinput[name], i, di[name])

  if reverse:
    next_i = i + 1
  else:
    next_i = i - 1
  return frozen_theta, dy, dstate, summed_dtheta, written_dinput, next_i
def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      hyp_ids, hyp_lens, done_hyps, other_states,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
  """Advances every greedy-search hypothesis by one token.

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    hyp_ids: An int tensor holding the ids chosen so far. Slot `cur_step`
      along its leading axis is overwritten with this step's ids.
      NOTE(review): the pre-existing documentation gave the shape as
      [num_hyps, tgt_seq_len], but the update indexes the leading axis by
      `cur_step` - confirm the layout against the caller.
    hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
      counted.
    done_hyps: Whether or not a hyp has finished.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Called with `(theta, encoder_outputs, step_ids, other_states, 1)`.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Called with
      `(theta, encoder_outputs, new_step_ids, new_other_states)`.

  Returns:
    A tuple of following elements for the next greedy search step,
    (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
  """
  params = self.params

  # Only hyps that are still running grow by one token.
  unfinished = 1 - tf.cast(done_hyps, tf.int32)
  hyp_lens = hyp_lens + unfinished

  # Greedy search keeps a single hypothesis per beam.
  num_hyps_per_beam = 1
  bs_results, new_other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  # Highest-scoring id per hyp, cast to int32 and shaped like `step_ids`.
  best_ids = tf.math.argmax(bs_results.log_probs, 1)
  best_ids = tf.cast(best_ids, tf.int32)
  new_step_ids = tf.reshape(best_ids, tf.shape(step_ids))

  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  # Record this step's ids in slot `cur_step` of the running buffer.
  flat_ids = tf.reshape(new_step_ids, [-1])
  hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step, flat_ids)

  # A hyp becomes (and stays) done once it emits the end-of-sequence id.
  emitted_eos = tf.equal(flat_ids, params.target_eos_id)
  done_hyps = tf.math.logical_or(done_hyps, emitted_eos)

  next_step = cur_step + 1
  return (next_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
          final_other_states)
def _GatherStep(x_in, t):
  """Reorders the slice at step `t` and writes it back into `x_in`.

  The slice `x_in[t]` is gathered with `correct_old_hyp_ids` (taken from the
  enclosing scope) and the result replaces slot `t` of `x_in` through an
  in-place update.

  Args:
    x_in: tensor whose leading axis is indexed by the step; the pre-existing
      documentation describes it as [T, B, ...].
    t: current time step.

  Returns:
    A pair `(updated x_in, t + 1)`.
  """
  step_slice = tf.gather(x_in, t)
  reordered = tf.gather(step_slice, correct_old_hyp_ids)
  updated = inplace_ops.alias_inplace_update(x_in, t, reordered)
  return updated, t + 1
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
  """Runs one decoding step inside the while loop.

  Relies on `decoder`, `impute_finished` and `zero_outputs` from the
  enclosing scope.

  Args:
    time: scalar int32 tensor.
    outputs_ta: structure of output buffers; slot `time` of each is written
      with this step's emitted output via an in-place update.
      NOTE(review): the name suggests TensorArrays, but the in-place update
      implies plain tensors - confirm against the caller.
    state: (structure of) state tensors and TensorArrays.
    inputs: (structure of) input tensors.
    finished: bool tensor (keeping track of what's finished).
    sequence_lengths: int32 tensor (keeping track of time of finish).

  Returns:
    `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
      next_sequence_lengths)`.
  """
  next_outputs, decoder_state, next_inputs, decoder_finished = decoder.step(
      time, inputs, state)

  # A decoder that tracks its own finished flags is trusted as-is; otherwise
  # anything finished earlier stays finished.
  if decoder.tracks_own_finished:
    next_finished = decoder_finished
  else:
    next_finished = tf.logical_or(decoder_finished, finished)

  # Entries that were still running on entry to this step now have length
  # time + 1; the rest keep their recorded length.
  was_running = tf.logical_not(finished)
  lengths_now = tf.fill(tf.shape(sequence_lengths), time + 1)
  next_sequence_lengths = tf.where(was_running, lengths_now,
                                   sequence_lengths)

  nest = contrib_framework.nest
  nest.assert_same_structure(state, decoder_state)
  nest.assert_same_structure(outputs_ta, next_outputs)
  nest.assert_same_structure(inputs, next_inputs)

  def _zero_if_finished(out, zero):
    # Zero out output values past finish.
    return tf.where(finished, zero, out)

  def _maybe_copy_state(new, cur):
    # Copy through states past finish; TensorArrays and scalar states get
    # passed through unchanged.
    if isinstance(cur, tf.TensorArray):
      return new
    new.set_shape(cur.shape)
    if new.shape.ndims == 0:
      return new
    return tf.where(finished, cur, new)

  if impute_finished:
    emit = nest.map_structure(_zero_if_finished, next_outputs, zero_outputs)
    next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
  else:
    emit = next_outputs
    next_state = decoder_state

  def _write_step(buf, out):
    # Store this step's output in slot `time` of its buffer.
    return inplace_ops.alias_inplace_update(buf, time, out)

  outputs_ta = nest.map_structure(_write_step, outputs_ta, emit)

  return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)