Example #1
def _cell_fn(theta, state0, acc_state, acc_gate, i):
  """RNN cell function."""
  input_slice = {k: tf.gather(inputs[k], i) for k in inputs}
  state1, gate = cell_fn(theta, state0, input_slice)
  # Write the step-i state and gate into the preallocated accumulators;
  # stop_gradient because the backward pass is computed manually (see the
  # gradient cell function in Example #2).
  for k in state0:
    if k not in skipped_state:
      acc_state[k] = tf.stop_gradient(
          inplace_ops.alias_inplace_update(acc_state[k], i, state1[k]))
  acc_gate = tf.stop_gradient(
      inplace_ops.alias_inplace_update(acc_gate, i, gate))
  # The time index walks backwards when the RNN is unrolled in reverse.
  return theta, state1, acc_state, acc_gate, i - 1 if reverse else i + 1
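For orientation, here is a minimal, self-contained sketch of what `inplace_ops.alias_inplace_update` does in the snippet above (note that `inplace_ops` lives under TensorFlow's non-public `tensorflow.python.ops` namespace): with a scalar index `i`, it writes `v` into row `i` of `x`, aliasing the input buffer instead of copying it, so `v` must have one rank less than `x`.

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import inplace_ops

tf.disable_eager_execution()

acc = tf.zeros([4, 3])   # preallocated accumulator, one row per time step
row = tf.fill([3], 7.0)  # value produced at step i
# With a scalar index, v has one rank less than x: updated[2, :] = row.
updated = inplace_ops.alias_inplace_update(acc, 2, row)

with tf.Session() as sess:
  print(sess.run(updated))  # rows 0, 1, 3 stay zero; row 2 is all 7s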
Example #2
def _cell_grad_fn_with_state0(state0, theta, dy, dstate1, dtheta, dinput, i):
  """Gradient cell function."""
  # Stop gradients on the inputs; backprop through time is done manually.
  state0 = {
      k: tf.stop_gradient(state0[k]) for k in state0 if k not in skipped_state
  }
  theta = {k: tf.stop_gradient(theta[k]) for k in theta}
  if "padding" in inputs:
    inputs_slice = {"padding": tf.gather(inputs["padding"], i)}
  else:
    inputs_slice = None
  gate = tf.gather(acc_gate, i)
  # Fold the incoming output gradient for step i into the state gradient.
  for k in dy:
    dstate1[k] = dstate1[k] + tf.gather(dy[k], i)
  dt, dstate, di = cell_grad(theta, state0, inputs_slice, gate, dstate1)
  # Accumulate parameter gradients and write input gradients in place.
  dtheta = {k: dtheta[k] + dt[k] for k in dtheta if k not in skipped_theta}
  dinput = {
      k: inplace_ops.alias_inplace_update(dinput[k], i, di[k]) for k in di
  }
  # The backward pass walks time in the opposite direction of the forward one.
  return theta, dy, dstate, dtheta, dinput, i + 1 if reverse else i - 1
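For context, a minimal sketch of how a per-step function like the ones above is typically driven. The names (`step_fn`, `max_t`) and the toy cell are illustrative assumptions, not part of the original code: a `tf.while_loop` threads the time index and the preallocated accumulator through the steps.

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import inplace_ops

tf.disable_eager_execution()

max_t = 5
inputs = tf.random.uniform([max_t, 2, 3])  # [T, B, D]
acc = tf.zeros([max_t, 2, 3])              # preallocated accumulator

def step_fn(i, acc):
  # Toy "cell": just copy the input slice for step i into the accumulator.
  x_t = tf.gather(inputs, i)               # [B, D]
  acc = inplace_ops.alias_inplace_update(acc, i, x_t)
  return i + 1, acc

_, acc = tf.while_loop(lambda i, _: i < max_t, step_fn,
                       [tf.constant(0), acc])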
Example #3

def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      hyp_ids, hyp_lens, done_hyps, other_states,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
  """Extend greedy search hyps for one step.

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed
      to the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
    hyp_lens: Valid lengths of all the hyps. Tokens after eos ids are not
      counted.
    done_hyps: Whether or not a hyp has finished.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors is of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. See class header comments for more details.

  Returns:
    A tuple of the following elements for the next greedy search step:
    (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states).
  """
  p = self.params
  # Increment hyp_lens by 1 if the hyp is not finished yet.
  hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))

  bs_results, new_other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, 1)  # num_hyps_per_beam
  new_step_ids = tf.math.argmax(bs_results.log_probs, 1)
  new_step_ids = tf.cast(new_step_ids, tf.int32)
  new_step_ids = tf.reshape(new_step_ids, tf.shape(step_ids))
  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  # Stash new_step_ids into the right slot of the full hyp_ids buffer.
  new_step_ids_1d = tf.reshape(new_step_ids, [-1])
  hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                             new_step_ids_1d)
  # Mark a hyp as done if its new step id is the end-of-sequence token.
  done_hyps = tf.math.logical_or(
      done_hyps, tf.equal(new_step_ids_1d, p.target_eos_id))

  return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
          final_other_states)
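To make the bookkeeping concrete, here is a toy illustration of the two updates above (the tensor values and `target_eos_id = 2` are made-up assumptions): unfinished hyps grow by one token, and a hyp is marked done once it emits the EOS id.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

target_eos_id = 2                       # assumed EOS id for illustration
done_hyps = tf.constant([False, True])  # hyp 1 finished on an earlier step
hyp_lens = tf.constant([3, 5])

# Increment hyp_lens by 1 only for hyps that are not finished yet.
hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))  # -> [4, 5]

new_step_ids_1d = tf.constant([2, 7])   # hyp 0 just emitted EOS
done_hyps = tf.math.logical_or(
    done_hyps, tf.equal(new_step_ids_1d, target_eos_id))  # -> [True, True]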
Example #4
def _GatherStep(x_in, t):
  """Gather for one time step.

  Args:
    x_in: A tensor of shape [T, B, ...]. We first take slice(t) from the
      tensor, then gather old_hyp_ids from that slice, and write the
      gathered slice back in place to update the original x_in.
    t: The current time step.

  Returns:
    The updated x_in and the next time step.
  """
  x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
  return inplace_ops.alias_inplace_update(x_in, t, x), t + 1
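A toy illustration of the gather-and-write-back pattern (the tensor contents and the `old_hyp_ids` permutation are made up): the slice at time `t` is reordered by the surviving hypothesis ids and written back in place.

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import inplace_ops

tf.disable_eager_execution()

x_in = tf.reshape(tf.range(12, dtype=tf.float32), [3, 4])  # [T=3, B=4]
old_hyp_ids = tf.constant([1, 1, 0, 3])  # each new hyp's parent in the beam
t = tf.constant(1)

x = tf.gather(tf.gather(x_in, t), old_hyp_ids)       # reorder row t
x_in = inplace_ops.alias_inplace_update(x_in, t, x)  # write it back in place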
Example #5
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
  """Internal while_loop body.

  Args:
    time: A scalar int32 tensor.
    outputs_ta: A structure of TensorArray.
    state: A (structure of) state tensors and TensorArrays.
    inputs: A (structure of) input tensors.
    finished: A bool tensor (keeping track of what's finished).
    sequence_lengths: An int32 tensor (keeping track of time of finish).

  Returns:
    `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
      next_sequence_lengths)`.
  """
  (next_outputs, decoder_state, next_inputs,
   decoder_finished) = decoder.step(time, inputs, state)
  if decoder.tracks_own_finished:
    next_finished = decoder_finished
  else:
    next_finished = tf.logical_or(decoder_finished, finished)
  next_sequence_lengths = tf.where(
      tf.logical_not(finished),
      tf.fill(tf.shape(sequence_lengths), time + 1),
      sequence_lengths)

  contrib_framework.nest.assert_same_structure(state, decoder_state)
  contrib_framework.nest.assert_same_structure(outputs_ta, next_outputs)
  contrib_framework.nest.assert_same_structure(inputs, next_inputs)

  # Zero out output values past finish.
  if impute_finished:
    emit = contrib_framework.nest.map_structure(
        lambda out, zero: tf.where(finished, zero, out),
        next_outputs, zero_outputs)
  else:
    emit = next_outputs

  # Copy through states past finish.
  def _maybe_copy_state(new, cur):
    # TensorArrays and scalar states get passed through.
    if isinstance(cur, tf.TensorArray):
      pass_through = True
    else:
      new.set_shape(cur.shape)
      pass_through = (new.shape.ndims == 0)
    return new if pass_through else tf.where(finished, cur, new)

  if impute_finished:
    next_state = contrib_framework.nest.map_structure(
        _maybe_copy_state, decoder_state, state)
  else:
    next_state = decoder_state

  outputs_ta = contrib_framework.nest.map_structure(
      lambda ta, out: inplace_ops.alias_inplace_update(ta, time, out),
      outputs_ta, emit)
  return (time + 1, outputs_ta, next_state, next_inputs,
          next_finished, next_sequence_lengths)
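Finally, a small sketch of the `impute_finished` masking used for `emit` above (toy values; `tf.compat.v1.where` accepts a rank-1 condition that selects along the first dimension of `x` and `y`): finished rows keep zeros, unfinished rows pass the new outputs through.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

finished = tf.constant([True, False])
next_out = tf.constant([[1., 1.], [2., 2.]])
zero_out = tf.zeros_like(next_out)

# Same pattern as the `emit` computation: zero out outputs past finish.
emit = tf.where(finished, zero_out, next_out)  # -> [[0., 0.], [2., 2.]]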