Ejemplo n.º 1
0
    def __init__(self, dtype, shape, send_device, recv_device, name=None):
        """Sets up a channel between a sending and a receiving device.

        Args:
          dtype: The dtype of tensors sent through the channel.
          shape: The shape of tensors sent through the channel. Must not be
            None, and must be fully defined when the endpoints are TPU cores.
          send_device: A fully-specified tensorflow device (sending side).
          recv_device: A fully-specified tensorflow device (receiving side).
          name: Optional channel name; "channel" is used when it is empty or
            None. The name is made unique within the current default graph.
        """
        graph = tf.get_default_graph()
        assert graph, "A channel is scoped within a tf.Graph"
        self._dtype = dtype
        self._send_device = send_device
        self._recv_device = recv_device
        self._name = graph.unique_name(name or "channel")

        assert shape is not None
        self._shape = tf.TensorShape(shape)

        self._send_tpu_core = _TpuCore(send_device)
        self._recv_tpu_core = _TpuCore(recv_device)
        self._send_called = False
        self._recv_op = None

        # A core id of -1 is treated as "not a TPU"; both endpoints have to
        # agree on whether they are TPU cores.
        send_on_tpu = self._send_tpu_core != -1
        recv_on_tpu = self._recv_tpu_core != -1
        assert send_on_tpu == recv_on_tpu, (
            "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))

        if self._send_tpu_core < 0:
            return

        # TPU-only requirements: static shape and two distinct cores.
        assert self._shape.is_fully_defined(), (
            "TPU channel must have fully defined shape. Name: %s, shape: %s"
            % (self._name, self._shape))
        assert self._send_tpu_core != self._recv_tpu_core, (
            "TPU send/recv must be cross-core: %s and %s" %
            (send_device, recv_device))
Ejemplo n.º 2
0
  def _get_input_shapes(self, *args):
    """Computes the per-micro-batch shape of every input.

    Args:
      *args: The inputs. When `params.nested_map_fprop` is set this must be a
        single `py_utils.NestedMap`; otherwise it is converted with `_ToTuple`.
        Individual entries may be None.

    Returns:
      The input shapes with the batch dimension (`params.batch_dim`) replaced
      by the micro batch size; None inputs map to None. A tuple, or a structure
      packed like `args[0]` when `params.nested_map_fprop` is set.

    Raises:
      ValueError: If micro_batch_size * num_micro_batches differs from the
        mini batch size.
    """
    p = self.params
    if not p.nested_map_fprop:
      input_tensors = _ToTuple(args)
    else:
      assert len(args) == 1
      assert isinstance(args[0], py_utils.NestedMap)
      input_tensors = py_utils.Flatten(args[0])

    # Read the static batch dimension off the non-None inputs. Every non-None
    # input is inspected and the value from the last one is the one kept.
    mini_batch_size = None
    for tensor in input_tensors:
      if tensor is None:
        continue
      mini_batch_size = tensor.get_shape().as_list()[p.batch_dim]
    assert mini_batch_size is not None

    micro_batch_size = p.micro_batch_size
    if not micro_batch_size:
      # No explicit micro batch size: derive it, first capping the number of
      # micro batches at the mini batch size (this updates the params object).
      if mini_batch_size < p.num_micro_batches:
        p.num_micro_batches = mini_batch_size
      micro_batch_size = mini_batch_size // p.num_micro_batches
    if (mini_batch_size is not None and
        micro_batch_size * p.num_micro_batches != mini_batch_size):
      raise ValueError('micro_batch_size * num_micro_batches != batch_size.')

    shapes = []
    for tensor in input_tensors:
      if tensor is None:
        shapes.append(None)
        continue
      dims = tensor.get_shape().as_list()
      dims[p.batch_dim] = micro_batch_size
      shapes.append(tf.TensorShape(dims))
    input_shapes = tuple(shapes)

    if p.nested_map_fprop:
      return py_utils.Pack(args[0], input_shapes)
    return input_shapes
Ejemplo n.º 3
0
 def ToTensorShape(self):
   """Builds a tf.TensorShape, leaving non-integer dimensions unknown.

   Returns:
     A tf.TensorShape whose i-th dimension is `int(d)` when the i-th entry `d`
     of `self._shape` reports both `is_number` and `is_integer`, and None
     (unknown) otherwise.
   """
   return tf.TensorShape([
       int(d) if d.is_number and d.is_integer else None for d in self._shape
   ])
def _GetShapes(tensors, none_shapes=False):
    """Maps a nested structure of tensors to the matching structure of shapes.

    Args:
      tensors: Structure of Tensors to get shapes for. Leaves that are not
        `tf.Tensor` instances are given `tf.TensorShape(None)`.
      none_shapes: If true, every dimension is reported as None; the number of
        dimensions is kept whenever the leaf's shape evaluates as truthy.

    Returns:
      The same structure as `tensors` (rebuilt through `type(tensors)`), holding
      the corresponding `TensorShape` objects.
    """

    def _LeafShape(leaf):
        """Returns the TensorShape to report for a single flattened leaf."""
        static_shape = leaf.get_shape() if isinstance(leaf,
                                                      tf.Tensor) else None
        if not none_shapes:
            return tf.TensorShape(static_shape)
        if static_shape:
            return tf.TensorShape([None] * len(static_shape))
        return tf.TensorShape(None)

    leaf_shapes = [_LeafShape(leaf) for leaf in tf.nest.flatten(tensors)]
    packed = tf.nest.pack_sequence_as(tensors, leaf_shapes)
    return type(tensors)(packed)
Ejemplo n.º 5
0
        def _CellFn(unused_theta, unused_state0, inputs):
            """Runs one self.body.FProp call on the per-step slice of inputs."""
            # Each per-step tensor gets the static shape of its stacked
            # counterpart with the leading dimension dropped.
            step_tensors = inputs.args + inputs.theta.Flatten()
            stacked_tensors = list(args) + theta_stack.Flatten()
            for step_t, stacked_t in zip(step_tensors, stacked_tensors):
                if stacked_t is None:
                    continue
                step_t.set_shape(
                    tf.TensorShape(stacked_t.shape.as_list()[1:]))

            # Run the wrapped layer and normalize its result to a tuple.
            outputs = _ToTuple(self.body.FProp(inputs.theta, *inputs.args))
            assert len(outputs) == len(out_shapes)

            # The outputs travel back through the state; no extras are emitted.
            next_state = py_utils.NestedMap(outputs=list(outputs))
            return next_state, py_utils.NestedMap()
Ejemplo n.º 6
0
 def _AssignVar(self, var_op):
     """Assigns a variable op to a device and accounts for its size.

     Pops the smallest `(allocated, device)` entry from `self._var_space_pq`,
     adds the variable's estimated byte size to it and pushes it back.

     Args:
       var_op: The variable op; its 'dtype' and 'shape' attrs are read.

     Returns:
       The device taken from the popped heap entry.
     """
     bytes_per_element = var_op.get_attr('dtype').size
     var_shape = tf.TensorShape(var_op.get_attr('shape'))
     assert self._var_space_pq, ('No ps devices to use.')
     allocated, device = heapq.heappop(self._var_space_pq)

     element_count = var_shape.num_elements()
     if element_count is not None:
         allocated += element_count * bytes_per_element
     else:
         assert var_op.name.endswith(
             'wb/var'), 'Unexpected name pattern: %s' % var_op.name
         # The shape of CuDNN RNN vars isn't known statically, so a constant
         # estimate is charged instead of introducing more complexity.
         allocated += 10 * 1024**2 * bytes_per_element

     heapq.heappush(self._var_space_pq, (allocated, device))
     tf.logging.info('Place variable %s on %s %d', var_op.name, device,
                     allocated)
     return device
Ejemplo n.º 7
0
        def _CellFn(unused_theta, state0, theta_i):
            """Runs self.body.FProp once, with arguments carried in the state."""
            # The FProp arguments for this step are stored in the state.
            fprop_inputs = _StateToArgs(state0)

            # Give each per-step theta tensor the static shape of its stacked
            # counterpart with the leading dimension dropped.
            for step_t, stacked_t in zip(theta_i.Flatten(),
                                         theta_stack.Flatten()):
                if stacked_t is None:
                    continue
                step_t.set_shape(
                    tf.TensorShape(stacked_t.shape.as_list()[1:]))

            # Run the wrapped layer; it must return as many values as it took.
            fprop_outputs = _ToTuple(self.body.FProp(theta_i, *fprop_inputs))
            assert len(fprop_outputs) == len(fprop_inputs)

            # The outputs become the next step's arguments via the state.
            next_state = _ArgsToState(fprop_outputs)
            return next_state, py_utils.NestedMap()
    def GreedySearchDecode(self,
                           theta,
                           encoder_outputs,
                           init_beam_search_state=None,
                           pre_beam_search_step_callback=None,
                           post_beam_search_step_callback=None,
                           max_steps=None):
        """Performs greedy-search based decoding.

        Decoding runs in a `tf.while_loop` that calls `self._GreedySearchStep`
        until either `max_steps` steps have been taken or every entry of
        `done_hyps` is true.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap containing encoder outputs to be passed
            to the callbacks.
          init_beam_search_state: The `InitBeamSearchState` callback. Please
            refer to the class header comments for more details. It is called
            unconditionally as `(theta, encoder_outputs, 1)`, so despite the
            None default a callable has to be supplied.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
          src_batch_size.

            - hyp_ids: [num_hyps, max_steps]. Hyps end with <eos> token if the
              <eos> token is encountered during search.
            - hyp_lens: [num_hyps].
            - done_hyps: [num_hyps], whether or not an eos is encountered.
        """
        p = self.params
        if max_steps is None:
            max_steps = p.target_seq_len

        # Greedy search is run as a search with a single hypothesis per beam.
        initial_results, other_states = init_beam_search_state(
            theta,
            encoder_outputs,
            1  # num_hyps_per_beam
        )

        # The number of hyps is taken from the leading dim of log_probs.
        num_hyps = tf.shape(initial_results.log_probs)[0]

        if 'step_ids' in initial_results:
            # Use the ids supplied by the callback; they must be [num_hyps, 1].
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            # Otherwise every hyp starts from the <sos> id.
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        cur_step = tf.constant(0, dtype=tf.int32)
        # Loop-carried result buffers. hyp_ids is kept time-major
        # ([max_steps, num_hyps]) inside the loop and transposed at the end.
        # NOTE(review): the initial contents depend on what
        # inplace_ops.empty(init=True) produces -- confirm against inplace_ops.
        done_hyps = inplace_ops.empty(shape=[num_hyps],
                                      dtype=tf.bool,
                                      init=True,
                                      name='done_hyps')
        hyp_lens = inplace_ops.empty(shape=[num_hyps],
                                     dtype=tf.int32,
                                     init=True,
                                     name='hyp_lens')
        hyp_ids = inplace_ops.empty(shape=[max_steps, num_hyps],
                                    dtype=tf.int32,
                                    init=True,
                                    name='hyp_ids')

        def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                         unused_hyp_lens, done_hyps, unused_other_states_list):
            # Keep going while steps remain and at least one hyp is not done.
            return tf.math.logical_and(
                cur_step < max_steps,
                tf.math.logical_not(tf.reduce_all(done_hyps)))

        def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                     other_states_list):
            # other_states is carried through the loop as a flat list, so it is
            # packed back into its NestedMap structure for the step function
            # and flattened again on the way out.
            (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
             new_other_states) = self._GreedySearchStep(
                 theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
                 done_hyps, other_states.Pack(other_states_list),
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        # The core loop variables keep their static shapes; the other states
        # are given all-None dimensions (none_shapes=True) as invariants.
        _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              tf.TensorShape(hyp_ids.get_shape()),
                              tf.TensorShape(hyp_lens.get_shape()),
                              tf.TensorShape(done_hyps.get_shape()),
                              _GetShapes(flat_other_states, none_shapes=True)))

        # transpose hyp_ids so it matches BeamSearchDecode's output
        final_hyp_ids = tf.transpose(final_hyp_ids)
        return final_hyp_ids, final_hyp_lens, final_done_hyps
    def BeamSearchDecode(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam_override=0,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
        """Performs beam-search based decoding.

        Runs `self._BeamSearchStep` inside a `tf.while_loop` until `max_steps`
        steps have been taken or the step reports `all_done`, then extracts the
        top-k terminated hypotheses per beam.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap containing encoder outputs to be passed
            to the callbacks. Mostly opaque to BeamSearchHelper, except that it
            should contain either a 'seq_lengths' field of shape
            [source_batch_size] or a 'padding' field of shape
            [source_max_lengths, source_batch_size] (the latter may also be a
            NestedMap, in which case its first flattened entry is used).
          num_hyps_per_beam_override: If set to a value <= 0, this parameter is
            ignored. If set to a value > 0, then this value will be used to
            override `p.num_hyps_per_beam`.
          init_beam_search_state: The `InitBeamSearchState` callback. Please
            refer to the class header comments for more details. It is called
            unconditionally, so despite the None default a callable has to be
            supplied.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A `BeamSearchDecodeOutput` built from, in order: the final done_hyps,
          topk_hyps, topk_ids, topk_lens, topk_scores, None, and the final
          other states packed like the initial `other_states`.
        """
        p = self.params
        num_hyps_per_beam = p.num_hyps_per_beam
        if num_hyps_per_beam_override > 0:
            num_hyps_per_beam = num_hyps_per_beam_override
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        # The number of hyps is taken from the leading dim of log_probs.
        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            # No ids supplied by the callback: start every hyp from <sos>.
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        # Initial beam search state. best_scores starts at a very negative
        # value; every other buffer starts zero-filled.
        min_score = -1e36
        best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
        in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
        in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
        bs_atten_probs = tf.zeros(
            [max_steps, num_hyps,
             tf.shape(initial_results.atten_probs)[1]],
            dtype=p.dtype)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        # Order matters: final_bs_states is indexed positionally after the loop.
        core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                          in_prev_hyps, in_done_hyps, bs_atten_probs)

        def LoopContinue(cur_step, all_done, unused_step_ids,
                         unused_core_bs_states, unused_other_states_list):
            # Keep going while steps remain and the search is not finished.
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
            # other_states is carried through the loop as a flat list, so it is
            # packed back into its NestedMap structure for the step function
            # and flattened again on the way out.
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta, encoder_outputs, cur_step, step_ids, core_bs_states,
                 other_states.Pack(other_states_list), num_hyps_per_beam,
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        # The core states keep their static shapes as invariants; the other
        # states are given all-None dimensions (none_shapes=True).
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(all_done.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              _GetShapes(core_bs_states),
                              _GetShapes(flat_other_states, none_shapes=True)))
        # Index 5 of core_bs_states is in_done_hyps:
        # [max_steps, num_beams * num_hyps_per_beam].
        final_done_hyps = final_bs_states[5]
        final_other_states = other_states.Pack(flat_final_other_states)

        # Assume that `padding` has shape [source_max_lengths, source_batch_size]
        # by default, and compute `encoded_seq_lengths` accordingly. This can be
        # overridden by directly passing `seq_lengths` in the `encoder_outputs`
        # NestedMap.
        encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
        if encoded_seq_lengths is None:
            source_paddings = encoder_outputs.padding
            # Length = rounded sum of (1 - padding) over axis 1 of the
            # transposed padding.
            if isinstance(source_paddings, py_utils.NestedMap):
                # Only the first flattened entry of a NestedMap padding is used.
                encoded_seq_lengths = tf.cast(
                    tf.round(
                        tf.reduce_sum(
                            1.0 - tf.transpose(source_paddings.Flatten()[0]),
                            1)), tf.int32)
            else:
                encoded_seq_lengths = tf.cast(
                    tf.round(
                        tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
                    tf.int32)

        # [num_beams, num_hyps_per_beam].
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            encoded_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        # [num_beams * num_hyps_per_beam, ...].
        # A Tensor max_steps cannot be passed as max_seq_length, so 0 is passed
        # instead. NOTE(review): the meaning of 0 is defined by ops.unpack_hyp
        # -- confirm there.
        max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

        return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                      topk_lens, topk_scores, None,
                                      final_other_states)
Ejemplo n.º 10
0
    def _BeamSearchDecodeIds(self,
                             theta,
                             encoder_outputs,
                             num_hyps_per_beam,
                             init_beam_search_state=None,
                             pre_beam_search_step_callback=None,
                             post_beam_search_step_callback=None,
                             max_steps=None):
        """Performs beam-search based decoding.

        Runs `self._BeamSearchStep` in up to two consecutive `tf.while_loop`s:
        an optional short-sequence loop (when `p.short_seq_limit > 0`) followed
        by the default loop, both bounded by `max_steps`.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap computed by encoder. Its `padding` field
            is read to compute the source sequence lengths.
          num_hyps_per_beam: Number of hyps per beam.
          init_beam_search_state: The InitBeamSearchState callback. Please
            refer to the class header comments for more details. It is called
            unconditionally, so despite the None default a callable has to be
            supplied.
          pre_beam_search_step_callback: The PreBeamSearchStepCallback
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The PostBeamSearchStepCallback
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A flat tuple made of three consecutive groups:

          1. Three shapes obtained with `py_utils.GetShape`, for
             source_seq_lengths, hyps and atten_probs (in that order).
          2. Eight tensors, each reshaped to rank 1, in this order (the shapes
             listed are the ones before flattening, b * k being the number of
             hyps):
               hyps: [time, b * k] ids of the token selected.
               prev_hyps: [time, b * k] index of the previous hyp selected.
               done_hyps: boolean [time, b * k], whether the hyp terminated.
               scores: [time, b * k] scores of the token selected.
               atten_probs: [time, b * k, seq_len] attention probabilities
                 over the source words.
               eos_scores: [time, b * k] scores of the eos token.
               eos_atten_probs: [time, b * k, seq_len] attention probabilities
                 for the eos token.
               source_seq_lengths: int32 lengths computed from the source
                 padding (presumably one per source sequence -- confirm
                 against the padding layout).
          3. The flattened final other states.
        """
        p = self.params
        # Honor the documented default. Without it a None max_steps would reach
        # tf.TensorArray(size=...) and the `cur_step < max_steps` comparison.
        if max_steps is None:
            max_steps = p.target_seq_len
        source_paddings = encoder_outputs.padding

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        # The number of hyps is taken from the leading dim of log_probs.
        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        # We cache the NestedMap as member variable so that we can use it to
        # pack the final outputs. Tpu rewrite methods forces us to strictly pass
        # in Tensors, and output Tensors
        self._other_states = other_states

        step_ids = tf.fill([num_hyps, 1],
                           tf.constant(p.target_sos_id, dtype=tf.int32))
        min_score = -1e36
        fprop_dtype = py_utils.FPropDtype(p)
        best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) +
                       min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
        histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
        # Per-step outputs are collected in TensorArrays with one slot per
        # step; done flags are stored as int32 and cast to bool at the end.
        in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        # States for beam search that are inputs into Beam search step.
        accum_bs_states = [best_scores, cumulative_scores, histories]
        # States that are not accumulators.
        non_accum_bs_states = [
            in_scores,
            in_hyps,
            in_prev_hyps,
            in_done_hyps,
            in_atten_probs,
            in_eos_scores,
            in_eos_atten_probs,
        ]
        # Order matters: final_bs_states is indexed positionally after the
        # loops (indices 0-2 are the accumulators, 3-9 the TensorArrays).
        core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

        flat_other_states = other_states.Flatten()

        # If there is an optimized implementation for short sequence, LoopBodyShort
        # will run first for short_seq_limit steps (after which the
        # LoopBodyShort does not have performance benefit). Then LoopBodyLong (the
        # default implementation) is used to continue the rest of the steps. For
        # decoders which do not have the short sequence specific implementation,
        # only the LoopBodyLong (the default implementation) will run.

        if p.short_seq_limit > 0:

            def LoopContinueShort(cur_step, all_done, unused_step_ids,
                                  unused_core_bs_states,
                                  unused_other_states_list):
                """Use short_seq optimization when cur_step is smaller than limit."""
                return tf.math.logical_and(cur_step < p.short_seq_limit,
                                           tf.math.logical_not(all_done))

            def LoopBodyShort(cur_step, unused_all_done, step_ids,
                              core_bs_states, other_states_list):
                """Loop body of short_seq optimization.

                Instead of doing computation for the entire padded sequence,
                while loop with early exit is used within each _BeamSearchStep
                to do computation for only the actual sequence
                (seq_length <= cur_step). use_short_seq_opt is used as the flag
                to pass this information down to the decoder implementation.

                Args:
                  cur_step: A scalar int tensor, the current time step,
                    0-based.
                  unused_all_done: A tf.bool, indicating whether the decoding
                    finishes.
                  step_ids: An int32 tensor of shape [num_hyps, 1]. The input
                    ids to the current search step.
                  core_bs_states: A tuple of core beam search states.
                  other_states_list: A flattened NestedMap of other beam search
                    states.

                Returns:
                  The updated input tuple, with the same shape.
                """
                (cur_step, all_done, new_step_ids, new_bs_states,
                 new_other_states) = self._BeamSearchStep(
                     theta,
                     encoder_outputs,
                     cur_step,
                     step_ids,
                     core_bs_states,
                     other_states.Pack(other_states_list),
                     num_hyps_per_beam,
                     pre_beam_search_step_callback,
                     post_beam_search_step_callback,
                     use_short_seq_opt=True)
                return (cur_step, all_done, new_step_ids, new_bs_states,
                        new_other_states.Flatten())

            # The results of the short loop seed the long loop below.
            (cur_step, all_done, step_ids, core_bs_states,
             flat_other_states) = tf.while_loop(
                 LoopContinueShort,
                 LoopBodyShort,
                 loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                            flat_other_states),
                 parallel_iterations=10,
                 back_prop=False,
                 swap_memory=False,
                 shape_invariants=(
                     tf.TensorShape(cur_step.get_shape()),
                     tf.TensorShape(all_done.get_shape()),
                     tf.TensorShape(step_ids.get_shape()),
                     tuple(
                         list(_GetShapes(accum_bs_states)) +
                         list(_GetShapes(non_accum_bs_states,
                                         none_shapes=True))),
                     _GetShapes(flat_other_states, none_shapes=True)),
                 maximum_iterations=max_steps)

        def LoopContinueLong(cur_step, all_done, unused_step_ids,
                             unused_core_bs_states, unused_other_states_list):
            """Continue default implementation until decoding finishes."""
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                         other_states_list):
            """Loop body of default long_seq implementation."""
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta,
                 encoder_outputs,
                 cur_step,
                 step_ids,
                 core_bs_states,
                 other_states.Pack(other_states_list),
                 num_hyps_per_beam,
                 pre_beam_search_step_callback,
                 post_beam_search_step_callback,
                 use_short_seq_opt=False)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        # NOTE(review): unlike the short loop, the other states keep their
        # static shapes as invariants here (none_shapes=False).
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinueLong,
            LoopBodyLong,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(
                tf.TensorShape(cur_step.get_shape()),
                tf.TensorShape(all_done.get_shape()),
                tf.TensorShape(step_ids.get_shape()),
                tuple(
                    list(_GetShapes(accum_bs_states)) +
                    list(_GetShapes(non_accum_bs_states, none_shapes=True))),
                _GetShapes(flat_other_states, none_shapes=False)),
            maximum_iterations=max_steps)

        # Source length = rounded sum of (1 - padding) over axis 1 of the
        # transposed padding. Only the first flattened entry of a NestedMap
        # padding is used.
        if isinstance(source_paddings, py_utils.NestedMap):
            source_seq_lengths = tf.cast(tf.round(
                tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                              1)),
                                         dtype=tf.int32)
        else:
            source_seq_lengths = tf.cast(tf.round(
                tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
                                         dtype=tf.int32)

        # Concatenate all outputs on axis=0.
        scores = final_bs_states[3].stack()
        hyps = final_bs_states[4].stack()
        prev_hyps = final_bs_states[5].stack()
        done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
        atten_probs = final_bs_states[7].stack()
        eos_scores = final_bs_states[8].stack()
        eos_atten_probs = final_bs_states[9].stack()
        rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
                eos_atten_probs, source_seq_lengths)

        # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
        # b/111131551 is resolved.
        # Canonical shapes for tensors of various ranks.
        r_shapes = [
            py_utils.GetShape(source_seq_lengths),
            py_utils.GetShape(hyps),
            py_utils.GetShape(atten_probs)
        ]
        # Reshape all tensors to [-1] to avoid cost of copy due to padding.
        rets_r1 = [tf.reshape(r, [-1]) for r in rets]

        return tuple(r_shapes) + tuple(rets_r1) + tuple(
            flat_final_other_states)
Ejemplo n.º 11
0
    def FProp(self, theta, *args):
        """Runs p.repeat copies of self.body.FProp independently.

        Args:
          theta: Layer model parameters. The shape of each variable in theta is
            always [p.repeat, ...]. And the i-th slice theta[i] becomes theta
            of the i-th copy of self.body.
          *args: Input arguments. The shape of each tensor in args is always
            [p.repeat, ....]. And the list [arg[i] for arg in args] becomes
            inputs to the i-th copy of self.body.FProp. Entries may be None,
            but at least one must be a tensor because the dtype of the
            initial state is taken from the first non-None entry.

        Returns:
          The accumulated output_tensors. Each tensor t in the return has the
          shape [p.repeat, ....] and the tuple (t[i] for t in output_tensors)
          is the return tuple of the i-th self.body.FProp. When exactly one
          input argument was given, only the first output tensor is returned
          instead of a tuple.

        Raises:
          ValueError: If no non-None input argument is given.
        """
        p = self.params
        # Keep what HasShape returns. The previous loop rebound its loop
        # variable, so the returned value was dropped and `args` was used
        # unchanged. HasShape presumably returns the tensor guarded by the
        # shape check -- confirm against py_utils.
        args = tuple(
            py_utils.HasShape(arg, [p.repeat], ndims=1)
            if arg is not None else None for arg in args)

        # The initial state needs a dtype; use the first input that is an
        # actual tensor rather than assuming args[0] is one.
        dtype_source = next((arg for arg in args if arg is not None), None)
        if dtype_source is None:
            raise ValueError(
                'FProp requires at least one non-None input argument.')
        state_dtype = dtype_source.dtype

        theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars,
                                            p.repeat)
        inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
        # Infer out_shapes from FPropMeta.
        out_shapes = self._InferOutShapes(args)

        def _CellFn(unused_theta, unused_state0, inputs):
            """Recurrent cell function wrapper of body.FProp."""
            # Sets shapes for both theta and inputs to self.body.FProp: each
            # per-step tensor takes the shape of its stacked counterpart with
            # the leading dimension dropped.
            for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                                list(args) + theta_stack.Flatten()):
                if src is not None:
                    dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

            # Runs the actual body.FProp
            fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
            fprop_outputs = _ToTuple(fprop_outputs)
            assert len(fprop_outputs) == len(out_shapes)
            # Passes fprop outputs to the next layer through state.
            state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
            return state1, py_utils.NestedMap()

        with tf.name_scope(p.name):
            # Initiate state0 with inferred output shapes.
            state0 = py_utils.NestedMap(outputs=[
                tf.zeros(shape, state_dtype) for shape in out_shapes
            ])
            # Runs body.FProp p.repeat times using Recurrent.
            acc_states, _ = recurrent.Recurrent(theta=py_utils.NestedMap(),
                                                state0=state0,
                                                inputs=inputs,
                                                cell_fn=_CellFn)

            # Retrieves fprop outputs from state1 and sets shapes: the
            # accumulated outputs gain a leading p.repeat dimension.
            output_tensors = tuple(acc_states.outputs)
            for output_tensor, out_shape in zip(output_tensors, out_shapes):
                output_tensor.set_shape(
                    tf.TensorShape([p.repeat] + out_shape.as_list()))

            # NOTE(review): the single-value return is keyed on the number of
            # inputs, not the number of outputs; kept as is for callers.
            return output_tensors[0] if len(args) == 1 else tuple(
                output_tensors)