Example #1
0
 def _TransposeAttentions(x):
     """Swaps the first two axes of `x`, leaving the third axis in place.

     `tf.transpose` with a length-3 permutation requires `x` to be rank 3, so
     an input of shape [d0, d1, d2] comes back as [d1, d0, d2].
     """
     swap_leading_axes = [1, 0, 2]
     return tf.transpose(x, perm=swap_leading_axes)
Example #2
0
  def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
    # pyformat: disable
    """Interpolates source ids in input_batch and interpolation_batch.

    Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
    It is a standard Transformer Encoder if interpolation_batch is None.

    When interpolation_batch is given, the token embeddings (and the task
    embeddings, if p.task_emb is set) of the two batches are mixed as
    lambdas[0] * input + lambdas[1] * interpolation, and a position is treated
    as padding only if it is padding in both batches.

    Args:
      theta: A `.NestedMap` object containing weights values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - embs: Optional; only read when interpolation_batch is given. If
          present and not None, it is used instead of the embeddings looked
          up from ids.
      interpolation_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - embs: Optional embeddings of ids. If present and not None, they are
          used instead of the embeddings looked up from ids.
      lambdas: A pair of tensors to combine embeddings of ids in input_batch and
        interpolation_batch. A trailing unit axis is appended to each before
        use; presumably each is [batch, time] (TODO: confirm with callers).
        Only read when interpolation_batch is given.

    Returns:
      A NestedMap of

        - encoded: The encoded features, either a tensor of shape
          [time, batch, depth], or a list of tensors if is_transparent is set in
          transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
          (and all layers), or None otherwise.
        - embedded_inputs: [batch, time, depth] embedded (and, when
          interpolation_batch is given, interpolated) input tokens, without
          task or positional embeddings. Note this is batch-major, unlike
          `encoded`.
    """
    # pyformat: enable

    p = self.params
    with tf.name_scope(p.name):
      src_segment_id = None
      src_segment_pos = None
      input_ids = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
          py_utils.assert_equal(tf.rank(input_batch.ids), 2)
      ], input_batch.ids)

      max_seq_length = None
      if (not py_utils.use_tpu() and
          FLAGS.transformer_encoder_truncates_inputs):
        # Drop trailing time steps that are padding for every example; the
        # assert checks that everything past the cut really is padding.
        max_seq_length = tf.cast(
            tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
            tf.int32)
        paddings = py_utils.with_dependencies([
            py_utils.assert_equal(
                tf.constant(True, tf.bool),
                tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
        ], input_batch.paddings)
        input_ids = input_ids[:, :max_seq_length]
        paddings = paddings[:, :max_seq_length]
        if p.packed_input:
          src_segment_id = input_batch.segment_ids[:, :max_seq_length]
          src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      else:
        paddings = input_batch.paddings
        if p.packed_input:
          src_segment_id = input_batch.segment_ids
          src_segment_pos = input_batch.segment_pos

      max_time = tf.shape(input_ids)[1]

      # Input token embeddings + positional embeddings
      if not p.shared_emb:
        input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                              tf.reshape(input_ids, [-1]))
      else:
        input_embs = self.softmax.EmbLookup(theta.softmax,
                                            tf.reshape(input_ids, [-1]))

      if interpolation_batch is not None:
        # NOTE(review): only input_batch is truncated above; interpolation_batch
        # is not, so with FLAGS.transformer_encoder_truncates_inputs the two
        # may disagree in time length -- confirm callers never combine them.
        other_input_ids = interpolation_batch.ids
        if not p.shared_emb:
          other_input_embs = self.token_emb.EmbLookup(
              theta.token_emb, tf.reshape(other_input_ids, [-1]))
        else:
          other_input_embs = self.softmax.EmbLookup(
              theta.softmax, tf.reshape(other_input_ids, [-1]))
        lambdas = [tf.expand_dims(a, -1) for a in lambdas]
        if 'embs' in input_batch and input_batch.embs is not None:
          input_embs = input_batch.embs
        if 'embs' in interpolation_batch and interpolation_batch.embs is not None:
          other_input_embs = interpolation_batch.embs
        # Bring both sides to [batch, time, dim] regardless of which of them
        # supplied precomputed embs, so the weighted sum below always combines
        # two rank-3 tensors. This is a no-op for embs that already have that
        # shape.
        input_embs = tf.reshape(
            input_embs,
            [-1, tf.shape(input_ids)[1], p.token_emb.embedding_dim])
        other_input_embs = tf.reshape(
            other_input_embs,
            [-1, tf.shape(other_input_ids)[1], p.token_emb.embedding_dim])
        input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs
        # A position stays padded only when it is padded in both batches.
        paddings = paddings + interpolation_batch.paddings - 1.0
        paddings = tf.clip_by_value(paddings, 0.0, 1.0)

      input_embs = tf.reshape(input_embs,
                              [-1, max_time, p.token_emb.embedding_dim])

      # [batch, time, dim], captured before task/positional embeddings.
      orig_input_embs = input_embs
      if p.task_emb:
        if interpolation_batch is None:
          input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                input_batch.task_ids)
        else:
          task_embs = self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)
          other_task_embs = self.task_emb.EmbLookup(
              theta.task_emb, interpolation_batch.task_ids)
          task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
          input_embs += task_embs

      if p.packed_input:
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, src_segment_pos)
      else:
        position_embs = self.position_emb.FProp(theta.position_emb, max_time)
        position_embs = tf.reshape(position_embs,
                                   [1, max_time, p.token_emb.embedding_dim])
      input_embs += position_embs

      if p.model_dim != p.token_emb.embedding_dim:
        input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

      # [time, batch]
      paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
      if p.packed_input:
        src_segment_id = tf.transpose(src_segment_id)

      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

      # [time, batch, dim]
      transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
    def GreedySearchDecode(self,
                           theta,
                           encoder_outputs,
                           init_beam_search_state=None,
                           pre_beam_search_step_callback=None,
                           post_beam_search_step_callback=None,
                           max_steps=None):
        """Performs greedy-search based decoding (one hyp per source example).

        The search state is initialized with num_hyps_per_beam=1 and each step
        is delegated to `self._GreedySearchStep`. Decoding stops after
        `max_steps` steps or as soon as every hypothesis is done.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap containing encoder outputs to be passed
            to the callbacks.
          init_beam_search_state: The `InitBeamSearchState` callback. Please
            refer to the class header comments for more details.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          max_steps: Maximum number of decoding steps. If None,
            self.params.target_seq_len is used.

        Returns:
          A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is the
          same as src_batch_size.

            - hyp_ids: [num_hyps, max_steps]. A hyp ends with the <eos> token
              if <eos> was encountered during search.
            - hyp_lens: [num_hyps].
            - done_hyps: [num_hyps], whether or not an eos is encountered.
        """
        p = self.params
        step_limit = p.target_seq_len if max_steps is None else max_steps

        # The last argument is num_hyps_per_beam: greedy search keeps one hyp.
        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, 1)
        num_hyps = tf.shape(initial_results.log_probs)[0]

        if 'step_ids' not in initial_results:
            # No explicit start ids: every hyp starts from target_sos_id.
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))
        else:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])

        def _InitBuffer(shape, dtype, name):
            # Initialized (init=True) inplace_ops buffer.
            return inplace_ops.empty(
                shape=shape, dtype=dtype, init=True, name=name)

        step0 = tf.constant(0, dtype=tf.int32)
        done0 = _InitBuffer([num_hyps], tf.bool, 'done_hyps')
        lens0 = _InitBuffer([num_hyps], tf.int32, 'hyp_lens')
        # Time-major accumulator: [max_steps, num_hyps].
        ids0 = _InitBuffer([step_limit, num_hyps], tf.int32, 'hyp_ids')

        def _ShouldContinue(step, unused_step_ids, unused_hyp_ids,
                            unused_hyp_lens, done, unused_other_states_list):
            return tf.logical_and(step < step_limit,
                                  tf.logical_not(tf.reduce_all(done)))

        def _RunStep(step, cur_ids, hyp_ids, hyp_lens, done, states_list):
            (next_step, next_ids, hyp_ids, hyp_lens, done,
             next_states) = self._GreedySearchStep(
                 theta, encoder_outputs, step, cur_ids, hyp_ids, hyp_lens,
                 done, other_states.Pack(states_list),
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (next_step, next_ids, hyp_ids, hyp_lens, done,
                    next_states.Flatten())

        flat_states = other_states.Flatten()
        loop_vars = (step0, step_ids, ids0, lens0, done0, flat_states)
        # The five dense loop vars keep their static shapes; the flattened
        # callback states use the shapes produced by _GetShapes.
        shape_invariants = tuple(
            tf.TensorShape(var.get_shape()) for var in loop_vars[:5]) + (
                _GetShapes(flat_states, none_shapes=True),)
        results = tf.while_loop(
            _ShouldContinue,
            _RunStep,
            loop_vars=loop_vars,
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=shape_invariants)
        final_hyp_ids, final_hyp_lens, final_done_hyps = results[2:5]

        # hyp_ids is accumulated time-major; transpose it to
        # [num_hyps, max_steps] so it matches BeamSearchDecode's output.
        return tf.transpose(final_hyp_ids), final_hyp_lens, final_done_hyps
Example #4
0
    def FProp(self, theta, input_batch):
        """Embeds source ids and runs them through the TransformerStack.

        Args:
          theta: A `.NestedMap` holding the weight values of this layer and of
            its child layers.
          input_batch: A `.NestedMap` with fields:

            - ids: The inputs tensor, expected shape [batch, time].
            - paddings: The paddings tensor, expected shape [batch, time].

        Returns:
          A NestedMap containing:

            - encoded: The transformer_stack output, either a tensor of shape
              [time, batch, depth], or a list of tensors if is_transparent is
              set in transformer_stack.
            - padding: of shape [time, batch]
            - segment_id: [time, batch] if packed inputs are supported by the
              model (and all layers), or None otherwise.
            - embedded_inputs: [time, batch, depth] token embeddings, taken
              before positional encodings are added.
        """
        p = self.params
        emb_dim = p.token_emb.embedding_dim
        with tf.name_scope(p.name):
            ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)
            paddings = input_batch.paddings
            segment_ids = input_batch.segment_ids if p.packed_input else None
            segment_pos = input_batch.segment_pos if p.packed_input else None

            truncate = (not py_utils.use_tpu() and
                        tf.flags.FLAGS.transformer_encoder_truncates_inputs)
            if truncate:
                # Cut the time axis down to the longest unpadded example; the
                # assert checks that everything past the cut is padding.
                seq_len = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - paddings, 1)), tf.int32)
                tail_is_padding = tf.reduce_all(paddings[:, seq_len:] > 0.5)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool), tail_is_padding)
                ], paddings)
                ids = ids[:, :seq_len]
                paddings = paddings[:, :seq_len]
                if p.packed_input:
                    segment_ids = segment_ids[:, :seq_len]
                    segment_pos = segment_pos[:, :seq_len]

            max_time = tf.shape(ids)[1]

            # [batch, time, dim]
            token_embs = tf.reshape(
                self.token_emb.EmbLookup(theta.token_emb,
                                         tf.reshape(ids, [-1])),
                [-1, max_time, emb_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(token_embs, [1, 0, 2])

            if p.packed_input:
                pos_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, segment_pos)
            else:
                pos_embs = tf.reshape(
                    self.position_emb.FProp(theta.position_emb, max_time),
                    [1, max_time, emb_dim])
            embs = token_embs + pos_embs

            if p.model_dim != emb_dim:
                embs = self.emb_proj.FProp(theta.emb_proj, embs)

            # [time, batch]
            paddings = tf.transpose(paddings)
            if p.packed_input:
                segment_ids = tf.transpose(segment_ids)
            embs = self.input_dropout.FProp(theta.input_dropout, embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(embs, [1, 0, 2])

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings, segment_ids)
        return py_utils.NestedMap(
            encoded=encoded,
            padding=padding,
            segment_id=segment_id,
            embedded_inputs=orig_input_embs)
Example #5
0
    def _InferenceSubgraph_Default(self):
        """Default inference subgraph.

        Returns:
          (fetches, feeds), with:

          - fetches: A dictionary of fetches, containing:

            - log_pplx_per_token: A matrix of shape [batch, time]. [i, j] is
              the i-th input text's j-th token's cross-entropy (negative log
              prob).
            - paddings: A matrix of shape [batch, time]. The padding mask.
            - lengths: A vector of shape [batch]. The number of non-padded
              tokens in each input.
            - log_pplx_per_sample: A vector of shape [batch]. [i] is the i-th
              input text's cross-entropy summed over its non-padded tokens.
            - num_oovs_per_sample: A vector of shape [batch] counting the
              total number of out-of-vocabulary tokens in each input.
            - tokens_from_labels: A vector of shape [batch] holding each
              input's label ids mapped back to strings using the vocabulary.
            - ids: A matrix of shape [batch, time], int32. [i, j] is the i-th
              input text's j-th token's id.

          - feeds: A dictionary of feeds, containing:

            - text: A placeholder for a vector of strings.
        """
        text = tf.placeholder(tf.string, shape=[None])
        # [batch, time]
        ids, labels, paddings = self.input_generator.StringsToIds(text)
        lengths = tf.reduce_sum(tf.cast(1 - paddings, tf.int32), axis=1)
        tokens_from_labels = self.input_generator.IdsToStrings(labels, lengths)
        oovs = tf.equal(labels, self.input_generator.tokenizer.unk_id)
        num_oovs_per_sample = tf.cast(
            tf.round(
                tf.reduce_sum(tf.cast(oovs, tf.float32) * (1 - paddings),
                              axis=1)), tf.int32)
        # From here on everything is time-major: [time, batch].
        ids, paddings, labels, weights = self._TrimIfPossibleThenTranspose(
            ids, paddings, labels, 1.0 - paddings)
        batch_size = tf.shape(ids)[1]
        xent_output, _ = self.lm.FPropDefaultTheta(
            inputs=ids,
            paddings=paddings,
            state0=self.lm.zero_state(self.theta.lm, batch_size),
            labels=py_utils.NestedMap(class_ids=labels, class_weights=weights))

        # [time, batch]
        per_example_xent = py_utils.HasShape(xent_output.per_example_xent,
                                             tf.shape(ids))
        # [batch]
        log_pplx_per_sample = tf.reduce_sum(per_example_xent * (1 - paddings),
                                            axis=0)
        # Every [time, batch] tensor is transposed back to the documented
        # batch-major layout before being fetched.
        fetches = {
            'log_pplx_per_token':  # [batch, time]
            tf.transpose(per_example_xent),
            'paddings':  # [batch, time]
            tf.transpose(paddings),
            'lengths':  # [batch]
            lengths,
            'log_pplx_per_sample':  # [batch]
            log_pplx_per_sample,
            'num_oovs_per_sample':  # [batch], int32
            num_oovs_per_sample,
            'tokens_from_labels':  # [batch], string
            tokens_from_labels,
            'ids':  # [batch, time], int32
            tf.transpose(ids),
        }
        feeds = {'text': text}
        return fetches, feeds
    def BeamSearchDecode(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam_override=0,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
        """Performs beam-search based decoding.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap containing encoder outputs to be passed
            to the callbacks. Mostly opaque to BeamSearchHelper, except that it
            should contain either a 'seq_lengths' field of shape
            [source_batch_size] or a 'padding' field of shape
            [source_max_lengths, source_batch_size]. If 'padding' is itself a
            NestedMap, only its first flattened tensor is used.
          num_hyps_per_beam_override: If set to a value <= 0, this parameter is
            ignored. If set to a value > 0, then this value will be used to
            override `p.num_hyps_per_beam`.
          init_beam_search_state: The `InitBeamSearchState` callback. Please
            refer to the class header comments for more details.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A `BeamSearchDecodeOutput`.
        """
        p = self.params
        num_hyps_per_beam = p.num_hyps_per_beam
        if num_hyps_per_beam_override > 0:
            num_hyps_per_beam = num_hyps_per_beam_override
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        min_score = -1e36
        best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
        in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
        in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
        bs_atten_probs = tf.zeros(
            [max_steps, num_hyps,
             tf.shape(initial_results.atten_probs)[1]],
            dtype=p.dtype)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        # Order matters: index 5 (in_done_hyps) is read back after the loop.
        core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                          in_prev_hyps, in_done_hyps, bs_atten_probs)

        def LoopContinue(cur_step, all_done, unused_step_ids,
                         unused_core_bs_states, unused_other_states_list):
            return tf.logical_and(cur_step < max_steps,
                                  tf.logical_not(all_done))

        def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta, encoder_outputs, cur_step, step_ids, core_bs_states,
                 other_states.Pack(other_states_list), num_hyps_per_beam,
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(all_done.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              _GetShapes(core_bs_states),
                              _GetShapes(flat_other_states, none_shapes=True)))
        # [target_seq_len, num_beams * num_hyps_per_beam].
        final_done_hyps = final_bs_states[5]
        final_other_states = other_states.Pack(flat_final_other_states)

        # Assume that `padding` has shape [source_max_lengths, source_batch_size]
        # by default, and compute `encoded_seq_lengths` accordingly. This can be
        # overridden by directly passing `seq_lengths` in the `encoder_outputs`
        # NestedMap.
        encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
        if encoded_seq_lengths is None:
            source_paddings = encoder_outputs.padding
            if isinstance(source_paddings, py_utils.NestedMap):
                # Only the first flattened paddings tensor is used.
                source_paddings = source_paddings.Flatten()[0]
            # [source_batch_size]
            encoded_seq_lengths = tf.cast(
                tf.round(
                    tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
                tf.int32)

        # [num_beams, num_hyps_per_beam].
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            encoded_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        # [num_beams * num_hyps_per_beam, ...].
        # A dynamic (tensor) max_steps is passed to unpack_hyp as 0.
        max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

        return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                      topk_lens, topk_scores, None,
                                      final_other_states)
Example #7
0
def MaxPool3D(points, point_features, pooling_idx, closest_idx):
  """Apply max pooling to a point cloud with computed sampling indices.

  pooling_idx and closest_idx are the outputs of a sampler such as
  FurthestPointSampler.

  The pooling operation results in a point cloud with fewer points, where the
  pooled points are specified by pooling_idx. Each element of pooling_idx
  contains an integer in the range [0, P1) containing the index of the point in
  points/point_features.

  Max pooling is performed by assigning each point to its closest pooled point,
  and then taking a max over the features of points assigned. We assume that
  this mapping is provided by closest_idx, where each element should contain
  an integer in the range [0, P2) containing the index of the pooled point that
  each point is assigned to.

  Note: This logic for pooling assumes that there will be at least
  one value > 0 per sampled region for each feature, otherwise it will return 0.
  Additionally, it does a reduce over a masked version of the features, so
  mean and min would not work without a change in the logic.

  Args:
    points: a floating point tf.Tensor with shape [N, P1, 3]
    point_features: a floating point tf.Tensor with shape [N, P1, C]
    pooling_idx: A tf.int32 tf.Tensor of shape [N, P2] with the index of which
      points we want to keep. Each value should be in the range [0, P1).
    closest_idx: A tf.int32 tf.Tensor of shape [N, P1] representing which
      sampled point is closest to each original point. Each value should be in
      the range [0, P2).

  Returns:
    A tuple of tf.Tensors (pooled_points, pooled_features).

    pooled_points has shape [N, P2, 3] representing the locations of each
    selected point. P2 corresponds to num_pooled_points.

    pooled_features has shape [N, P2, C] representing the pooled features at
    each point.
  """
  batch_size, num_points = py_utils.GetShape(points, 2)
  point_features = py_utils.HasShape(point_features,
                                     [batch_size, num_points, -1])
  pooling_idx = py_utils.HasShape(pooling_idx, [batch_size, -1])
  closest_idx = py_utils.HasShape(closest_idx, [batch_size, num_points])
  _, num_output_points = py_utils.GetShape(pooling_idx)

  # Gather new point locations.
  pooled_points = tf.array_ops.batch_gather(points, pooling_idx)

  # Build the assignment mask in the feature dtype so the masked multiply below
  # (and map_fn, whose output dtype defaults to the elems dtype) also works for
  # non-float32 features.
  mask = tf.one_hot(
      closest_idx, num_output_points, dtype=point_features.dtype)  # [N, P1, P2]
  mask = tf.transpose(mask, [2, 0, 1])  # [P2, N, P1]

  def _PartialPoolFeaturesFn(partial_mask):
    """Max-pools the features of the points assigned to one pooled point."""
    # [N, P1] -> [N, P1, 1]; broadcasting spreads it over the C channels.
    partial_mask = tf.reshape(partial_mask, [batch_size, num_points, 1])
    # Note: This method of pooling assumes there will be a value > 0
    # And will only work with max under this condition.
    return tf.reduce_max(partial_mask * point_features, axis=1)  # [N, C]

  # Performing a map_fn over the pooled points is more memory efficient.
  pooled_point_features = tf.map_fn(_PartialPoolFeaturesFn, mask)  # [P2, N, C]
  pooled_point_features = tf.transpose(pooled_point_features,
                                       [1, 0, 2])  # [N, P2, C]

  return pooled_points, pooled_point_features
  def FProp(self,
            theta,
            source_input,
            source_paddings,
            target_input=None,
            target_paddings=None,
            source_segment_id=None,
            target_segment_id=None,
            labels=None,
            label_weights=None,
            source_pos_id=None,
            target_pos_id=None,
            source_task_id=None,
            target_task_id=None):
    """Transforms source sequence of Tensors with Transformers layers.

    When p.softmax_tpl is set, the resulting logits are additionally scored
    against `labels` with the softmax of the last cell (optionally with label
    smoothing when p.label_smoothing is set).

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_input:  A sequence of ints indicating source input ids of [time,
        batch] shape or [batch, time] if batch_dim is 0.
      source_paddings: A sequence of 0s and 1s indicating input paddings of
        [time, batch] shape or [batch, time] if batch_dim is 0.
      target_input: A sequence of ints indicating target input ids of [time,
        batch] shape or [batch, time] if batch_dim is 0.
      target_paddings: [target_time, target_batch] or [target_batch,
        target_time] if batch_dim is 0.
      source_segment_id: A sequence of ints indicating source segment ids of
        [time, batch] shape or [batch, time] if batch_dim is 0.
      target_segment_id: A sequence of ints indicating target segment ids of
        [time, batch] shape or [batch, time] if batch_dim is 0.
      labels: A sequence of ints indicating label ids of [time, batch] shape,
        or [batch, time] if batch_dim is 0. Required if p.softmax_tpl is set.
      label_weights: A sequence of floats indicates label weights of [time,
        batch] shape, or [batch, time] if batch_dim is 0. Required if
        p.softmax_tpl is set.
      source_pos_id: A sequence of ints indicating source position ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.
      target_pos_id: A sequence of ints indicating target position ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.
      source_task_id: A sequence of ints indicating source task ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.
      target_task_id: A sequence of ints indicating target task ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.

    Returns:
      If p.softmax_tpl is not set: transformer_output with shape
      [time, batch, dim] or [batch, time, dim] if batch_dim is 0.

      Otherwise a tuple (per_example_xent, logits), where logits is the
      transformer output above and per_example_xent has the leading two
      dimensions of logits ([time, batch], or [batch, time] if batch_dim is 0).

    Raises:
      AssertionError: if a required input is missing for the configured params.
    """
    p = self.params
    if p.num_decoder_layers > 0:
      assert target_input is not None
      assert target_paddings is not None
    if p.packed_input:
      assert source_segment_id is not None, (
          'Need to specify source_segment_id if packed input is supported.')
      assert source_pos_id is not None, (
          'Need to specify source_pos_id for packed input and embeddings.')

    logits = super(GPipeTransformerStack,
                   self).FProp(theta, source_input, source_paddings,
                               target_input, target_paddings, source_segment_id,
                               target_segment_id, source_pos_id, target_pos_id,
                               source_task_id, target_task_id)
    if not p.softmax_tpl:
      return logits

    # Fail early with a clear message instead of inside tf.reshape(None, ...).
    assert labels is not None, 'labels are required when softmax_tpl is set.'
    assert label_weights is not None, (
        'label_weights are required when softmax_tpl is set.')

    target_probs = None
    if p.label_smoothing:
      if p.batch_dim:  # Time-major
        # Transpose to [batch, time] for the smoother, then back to time-major.
        target_probs = tf.transpose(
            self.smoother.FProp(
                theta.smoother,
                tf.transpose(target_paddings),
                tf.transpose(labels),
                target_ids=None), [1, 0, 2])
      else:
        target_probs = self.smoother.FProp(
            theta.smoother, target_paddings, labels, target_ids=None)
      target_probs = tf.reshape(target_probs, [-1, p.softmax_tpl.num_classes])
    reshaped_logits = tf.reshape(logits, [-1, p.softmax_tpl.num_classes])
    tgt_labels = tf.reshape(labels, [-1])

    # The softmax used for the loss belongs to the last cell.
    last_cell = 'cell_{}'.format(len(p.splits) - 1)
    softmax = self.children[last_cell].softmax
    softmax_theta = theta[last_cell].softmax
    per_example_xent, _ = softmax.XentLossFromLogits(
        softmax_theta,
        reshaped_logits,
        class_weights=tf.reshape(label_weights, [-1]),
        class_ids=tgt_labels,
        class_probabilities=target_probs)
    # Restore the leading two dims of logits on the flattened per-token loss.
    per_example_xent = tf.reshape(per_example_xent, tf.shape(logits)[:2])
    return per_example_xent, logits
Exemple #9
0
    def _testDecoderFPropFloatHelper(self,
                                     func_inline=False,
                                     num_decoder_layers=1,
                                     target_seq_len=5,
                                     residual_start=0):
        """Computes decoder from params and computes loss with random inputs.

        Builds a fresh graph, instantiates the decoder from
        self._DecoderParams, runs FPropDefaultTheta on fixed encoder outputs
        and targets, evaluates the merged summaries and returns the loss.

        Args:
          func_inline: forwarded as do_function_inlining in the session's
            graph optimizer options.
          num_decoder_layers: assigned to the decoder params' rnn_layers.
          target_seq_len: assigned to the decoder params' target_seq_len.
          residual_start: assigned to the decoder params' residual_start.

        Returns:
          The evaluated value of metrics['loss'][0].
        """
        test_cluster = cluster_factory.ForTestingWorker(add_summary=True)
        session_config = tf.ConfigProto(
            graph_options=tf.GraphOptions(
                optimizer_options=tf.OptimizerOptions(
                    do_function_inlining=func_inline)))
        with test_cluster, self.session(
                graph=tf.Graph(), use_gpu=False,
                config=session_config) as sess:
            # Seed the graph before creating any op so results are repeatable.
            tf.set_random_seed(8372749040)
            dec_params = self._DecoderParams(
                py_utils.VariationalNoiseParams(None, False, False))
            dec_params.rnn_layers = num_decoder_layers
            dec_params.residual_start = residual_start
            dec_params.target_seq_len = target_seq_len
            decoder = dec_params.Instantiate()

            # Encoder outputs of shape [src_len=5, batch=2, depth=8].
            encoder_outputs = py_utils.NestedMap(
                encoded=tf.random_normal([5, 2, 8], seed=9283748),
                padding=tf.constant(
                    [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0],
                     [1.0, 1.0]],
                    dtype=tf.float32))

            def _Transposed(rows, dtype):
                return tf.transpose(tf.constant(rows, dtype=dtype))

            ids = _Transposed([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                               [5, 6, 7, 8], [10, 5, 2, 5]], tf.int32)
            labels = _Transposed([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                                  [5, 7, 8, 10], [10, 5, 2, 4]], tf.int32)
            paddings = _Transposed([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0],
                                    [0, 1, 0, 0], [1, 1, 1, 1]], tf.float32)
            transcripts = tf.constant(
                ['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
            targets = py_utils.NestedMap({
                'ids': ids,
                'labels': labels,
                'weights': 1.0 - paddings,
                'paddings': paddings,
                'transcripts': transcripts,
            })

            metrics = decoder.FPropDefaultTheta(encoder_outputs,
                                                targets).metrics
            loss = metrics['loss'][0]
            correct_predicts = (
                metrics['fraction_of_correct_next_step_preds'][0])
            summaries = tf.summary.merge(
                tf.get_collection(tf.GraphKeys.SUMMARIES))

            tf.global_variables_initializer().run()
            fetched = sess.run([loss, correct_predicts])
            summaries.eval()
            return fetched[0]
Exemple #10
0
    def Sample(self,
               decoder_theta,
               encoder_outputs,
               random_seed,
               init_state_callback,
               pre_step_callback,
               post_step_callback,
               init_step_ids=None):
        """Samples target sequences, one target sequence per source sequence.

        (Please see beam_search_helper.py for description of decoder callbacks.)

        At every step the next id is drawn with
        tf.random.stateless_categorical from the step's log-probs divided by
        p.temperature, seeded with (random_seed, timestep). When p.top_k > 0
        the draw is restricted to the p.top_k highest log-probs (re-normalized
        with log_softmax if p.top_k_renormalize is set).

        The loop runs for at most p.target_seq_len steps, either with
        recurrent.Recurrent (p.use_recurrent) or with a tf.while_loop that
        accumulates ids/logits in TensorArrays. On the Recurrent path early
        stopping on p.target_eos_id is only enabled when p.use_stop_fn is set;
        the tf.while_loop path always uses StopFn as its loop condition and
        exits once the latest id of every sequence equals p.target_eos_id.

        Note: if encoder_outputs.segment_id is None, the 'segment_id' entry is
        deleted from the caller's `encoder_outputs` NestedMap in place.

        Args:
          decoder_theta: A NestedMap object containing weights' values of the
            decoder layer and its children layers, to be passed to decoder
            callbacks.
          encoder_outputs: the outputs of the encoder, to be passed to
            callbacks.
          random_seed: a scalar int32 tensor representing the random seed.
          init_state_callback: decoder._InitBeamSearchStateCallback.
          pre_step_callback: decoder._PreBeamSearchStepCallback.
          post_step_callback: decoder._PostBeamSearchStepCallback.
          init_step_ids: [batch], optional init step ids, default to SOS.

        Returns:
          A NestedMap containing the following tensors

          - 'logits': [batch, max_target_length, vocab_size], representing the
            distribution from which target sequences are sampled.
          - 'ids': [batch, max_target_length] of int32, representing the target
            sequence ids, not including target_sos_id, but maybe ending with
            target_eos_id if end-of-sequence is reached before target_seq_len.
            Ids at padded timesteps are overwritten with target_eos_id.
          - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
            a padded timestep (computed by _ComputePaddings from the ids).
        """
        p = self.params
        assert p.temperature > 0
        assert p.top_k >= 0
        assert p.num_hyps_per_beam >= 1
        if getattr(encoder_outputs, 'segment_id', 1) is None:
            # Remove None values, which are not supported by recurrent.
            # Note: this mutates the caller's encoder_outputs in place.
            del encoder_outputs['segment_id']
        # init_state_callback may modify 'encoder_outputs', e.g., by inserting
        # 'packed_src'.
        bs_result, bs_state = init_state_callback(decoder_theta,
                                                  encoder_outputs,
                                                  p.num_hyps_per_beam)
        # 'recurrent_theta' holds the cross-timestep information passed
        # explicitly through the loop below: the random seed and the encoder
        # outputs. decoder_theta is not included; Step reads it via closure.
        recurrent_theta = py_utils.NestedMap(random_seed=random_seed,
                                             encoder_outputs=encoder_outputs)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=init_step_ids if init_step_ids is not None else tf.fill(
                [batch], tf.cast(p.target_sos_id, tf.int32)),
            bs_state=bs_state)

        if p.use_recurrent:
            # Step ignores its inputs on this path; the leading dimension of
            # 'dummy' (p.target_seq_len) presumably sets the number of
            # Recurrent steps.
            inputs = py_utils.NestedMap(
                dummy=tf.zeros([p.target_seq_len, batch]))
        else:
            # Per-step ids and logits are written into these TensorArrays and
            # stacked (time-major) after the while_loop.
            inputs = py_utils.NestedMap(
                ids=tf.TensorArray(dtype=tf.int32, size=p.target_seq_len),
                logits=tf.TensorArray(dtype=bs_result.log_probs.dtype,
                                      size=p.target_seq_len),
            )

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step.

            Returns (state1, empty NestedMap) on the Recurrent path, or the
            updated (recurrent_theta, state1, inputs) loop variables on the
            tf.while_loop path.
            """
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=p.num_hyps_per_beam)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs

                if p.top_k > 0:
                    # Sample among the top_k candidates only; topk_ids maps the
                    # sampled position back to a vocabulary id below.
                    topk_logits, topk_ids = tf.math.top_k(state1.logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits
                else:
                    sample_logits = state1.logits

                # Sample ids from logits. [batch].
                # The (random_seed, timestep) pair gives each step its own seed.
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                # On the last chunk, an end-of-chunk id is turned into
                # end-of-sequence.
                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)

        if p.use_recurrent:

            def StopFn(t, theta, state):
                del t, theta  # Unused: this stop function only uses the state ids.
                return tf.equal(state.ids, p.target_eos_id)
        else:

            # Used as the tf.while_loop condition: keep looping while at least
            # one sequence has not produced target_eos_id at the latest step.
            def StopFn(recurrent_theta, state, inputs):
                del recurrent_theta, inputs
                return tf.logical_not(
                    tf.reduce_all(tf.equal(state.ids, p.target_eos_id)))

        # stop_fn only affects the Recurrent path; the while_loop below always
        # uses StopFn directly.
        if p.use_stop_fn:
            stop_fn = StopFn
        else:
            stop_fn = None

        if p.use_recurrent:
            accumulated_states, _ = recurrent.Recurrent(
                recurrent_theta,
                recurrent_state0,
                inputs,
                Step,
                stop_fn=stop_fn,
                allow_implicit_capture=True)
        else:
            loop_vars = (recurrent_theta, recurrent_state0, inputs)
            (_, _, accumulated_states) = tf.while_loop(
                StopFn,
                Step,
                loop_vars=loop_vars,
                shape_invariants=_GetShapes(loop_vars, none_shapes=True),
                back_prop=False,
                maximum_iterations=p.target_seq_len)
            # NOTE(review): if the loop exits early, fewer than
            # p.target_seq_len TensorArray entries have been written, while the
            # set_shape calls below assume a time dimension of
            # p.target_seq_len -- confirm unwritten entries are handled.
            accumulated_states.ids = accumulated_states.ids.stack()
            accumulated_states.logits = accumulated_states.logits.stack()

        # Accumulated values are time-major; transpose to batch-major.
        result = py_utils.NestedMap(logits=tf.transpose(
            accumulated_states.logits, [1, 0, 2]),
                                    ids=tf.transpose(accumulated_states.ids))
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
Exemple #11
0
def expand_tensor(tensor, block_dims):
    """Expands a 2D tensor by replicating the tensor values.

    This is equivalent to the Kronecker product of the tensor and a matrix of
    ones of size block_dims.

    Example::

      tensor = [[1,2]
                [3,4]]
      block_dims = [2,2]

      result = [[1 1 2 2]
                [1 1 2 2]
                [3 3 4 4]
                [3 3 4 4]]

    Args:
      tensor: A 2D tensor that needs to be expanded. Its shape must be known
        statically, because the scatter indices are built at graph
        construction time from `tensor.shape.as_list()`.
      block_dims: List of two integers [block_height, block_width] specifying
        the expansion factor along rows and columns. A factor <= 1 leaves the
        corresponding dimension unchanged.

    Returns:
      The expanded tensor.

    Raises:
      ValueError: if tensor is not rank-2 or block_dims does not have 2
        elements.
    """
    if tensor.get_shape().ndims != 2:
        raise ValueError('Input tensor must be rank 2')

    if len(block_dims) != 2:
        raise ValueError('block_dims must have 2 elements')

    block_height, block_width = block_dims

    def _tile_rows(tensor, multiple):
        """Create a new tensor by tiling the tensor along rows."""
        return tf.tile(tensor, [multiple, 1])

    def _generate_indices(num_rows, block_dim):
        """Returns [num_rows * block_dim, 1] int32 scatter targets.

        Tiled row k * num_rows + r (copy k of row r) is sent to output row
        r * block_dim + k, so the copies of each row end up adjacent.
        """
        copy_offsets = np.arange(block_dim, dtype=np.int32)[:, np.newaxis]
        row_starts = np.arange(num_rows, dtype=np.int32)[np.newaxis, :]
        return (copy_offsets + row_starts * block_dim).reshape(-1, 1)

    def _replicate_rows(tensor, multiple):
        """Repeats every row of `tensor` `multiple` times, in place."""
        tensor_shape = tensor.shape.as_list()
        expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]]
        indices = tf.constant(_generate_indices(tensor_shape[0], multiple))
        return tf.scatter_nd(indices, _tile_rows(tensor, multiple),
                             expanded_shape)

    expanded_tensor = tensor

    # Expand rows by factor block_height.
    if block_height > 1:
        expanded_tensor = _replicate_rows(tensor, block_height)

    # Transpose and expand by factor block_width. Transpose the result.
    if block_width > 1:
        expanded_tensor = tf.transpose(
            _replicate_rows(tf.transpose(expanded_tensor), block_width))

    return expanded_tensor
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate vlaues with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        # Create prefix of start tokens.
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(beam_size - 1, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feautre.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ - - - - - - P P P P P P P* - - - ]   ^
        # [ - - P P P P P P P P P P P* - - - ]   | batch
        # [ - - - - - - - - - - - P P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_indexes = pfx_time - pfx_max + tf.expand_dims(pfx_len - 1, 1)
        pfx_pad = tf.cast(tf.greater_equal(pfx_indexes, 0),
                          tf.int32)  # Exclude final pfx token.
        pfx_id = tf.roll(pfx, shift=1, axis=-1) * pfx_pad
        pfx_last = pfx[:, -1]

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        # Remove padding.
        assert buf_size > pfx_max
        pfx_pad_long = tf.pad(pfx_pad, [(0, 0), (0, buf_size - pfx_max)],
                              constant_values=1)
        pfx_mask *= tf.cast(tf.expand_dims(pfx_pad_long, axis=1), tf.float32)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_indexes * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here we make initial 't' constant and determined by the
        # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic
        # as t ~  max(pfx_len) / beam_size and this will more steps for beam search
        # however 'max' results in a very slow all-to-all for 'max' on 16x16
        # and variable number of decoder steps may result in bad latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, _, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
Exemple #13
0
  def FProp(self, theta, input_batch):
    """Embeds source ids and transforms with TransformerStack.

    When not running on TPU and FLAGS.transformer_encoder_truncates_inputs is
    set, trailing time steps that are padding for every example in the batch
    are dropped before embedding; an assertion checks that only padded
    positions are removed.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task
          ids of shape [batch, time].
        - segment_ids: Required if p.packed_input or
          p.individually_tagged_input is set. With
          p.individually_tagged_input it must have the same shape as ids and
          its values are looked up in the token embedding table as tags.
        - segment_pos: Required if p.packed_input is set; passed to the
          position embedding layer.

    Returns:
      A NestedMap containing

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
    """

    p = self.params
    with tf.name_scope(p.name):
      src_segment_id = None
      src_segment_pos = None
      input_ids = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
          py_utils.assert_equal(tf.rank(input_batch.ids), 2)
      ], input_batch.ids)

      # Off-TPU, optionally cut the time axis down to the longest unpadded
      # example; the assertion verifies that everything cut is padding.
      if (not py_utils.use_tpu() and
          FLAGS.transformer_encoder_truncates_inputs):
        max_seq_length = tf.cast(
            tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
            tf.int32)
        paddings = py_utils.with_dependencies([
            py_utils.assert_equal(
                tf.constant(True, tf.bool),
                tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
        ], input_batch.paddings)
        input_ids = input_ids[:, :max_seq_length]
        paddings = paddings[:, :max_seq_length]
        if p.packed_input:
          src_segment_id = input_batch.segment_ids[:, :max_seq_length]
          src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      else:
        paddings = input_batch.paddings
        if p.packed_input:
          src_segment_id = input_batch.segment_ids
          src_segment_pos = input_batch.segment_pos

      # Possibly-truncated time length; every per-token tensor below must
      # agree with it.
      max_time = tf.shape(input_ids)[1]

      # Input token embeddings + positional embeddings
      if not p.shared_emb:
        input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                              tf.reshape(input_ids, [-1]))
      else:
        input_embs = self.softmax.EmbLookup(theta.softmax,
                                            tf.reshape(input_ids, [-1]))

      input_embs = tf.reshape(input_embs,
                              [-1, max_time, p.token_emb.embedding_dim])
      # [time, batch, dim]
      orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

      if p.packed_input:
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, src_segment_pos)
      else:
        position_embs = self.position_emb.FProp(theta.position_emb, max_time)
        position_embs = tf.reshape(position_embs,
                                   [1, max_time, p.token_emb.embedding_dim])
      # Position embeddings are simply added to token embeddings.
      input_embs += position_embs

      if p.individually_tagged_input:
        assert not p.packed_input
        # Look up tag embeddings; this assumes that the tags arriving on
        # input_batch.segment_ids (originating as common.source_segment_id
        # in the input NMTExample) have been reserved in the WPM vocabulary
        # as context tags, e.g. the ids for <src_token> and <ctxt_token> in
        # wide source context experiments.
        input_tags = py_utils.with_dependencies([
            py_utils.assert_shape_match(
                tf.shape(input_batch.segment_ids), tf.shape(input_batch.ids)),
            py_utils.assert_equal(tf.rank(input_batch.segment_ids), 2)
        ], input_batch.segment_ids)
        # input_ids may have been truncated to max_seq_length above; apply the
        # same truncation to the tags so tag_embeddings lines up with
        # input_embs. Without truncation max_time is the full length and this
        # slice keeps every position.
        input_tags = input_tags[:, :max_time]
        tag_embeddings = self.token_emb.EmbLookup(theta.token_emb,
                                                  tf.reshape(input_tags, [-1]))
        tag_embeddings = tf.reshape(tag_embeddings,
                                    [-1, max_time, p.token_emb.embedding_dim])
        # Concatenate the tag embeddings to the input embeddings, and then
        # project back to the original embedding dimensionality.
        concat_embs = tf.concat([input_embs, tag_embeddings], -1)
        input_embs = self.concat_emb_and_tag_proj.FProp(
            theta.concat_emb_and_tag_proj, concat_embs)

      if p.ln_input:
        input_embs = self.layer_norm_input.FProp(theta.layer_norm_input,
                                                 input_embs)

      if p.task_emb:
        input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)

      summary_utils.histogram('input_embs', input_embs)
      if p.model_dim != p.token_emb.embedding_dim:
        input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)
        summary_utils.histogram('emb_proj', input_embs)

      # The transformer stack is time-major: paddings -> [time, batch].
      paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
      if p.packed_input:
        src_segment_id = tf.transpose(src_segment_id)
      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

      # [time, batch, dim]
      transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Training only: treat positions holding p.source_mask_id as padding.
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
Exemple #14
0
def _SingleClassDecodeWithNMS(predicted_bboxes,
                              classification_scores,
                              nms_iou_threshold,
                              score_threshold,
                              max_boxes_per_class=None):
    """Perform NMS on predicted bounding boxes / associated logits.

    NMS is class-agnostic: every box is scored by its maximum class score
    (zeroed when class 0, treated as background, is the argmax). The single
    NMS result is then tiled across classes so the output layout matches
    the per-class NMS variant.

    Args:
      predicted_bboxes: [batch_size, num_boxes, 7] float Tensor containing
        predicted bounding box coordinates.
      classification_scores: [batch_size, num_boxes, num_classes] float Tensor
        containing predicted classification scores for each box.
      nms_iou_threshold: IoU threshold to use when determining whether two
        boxes overlap for purposes of suppression. Must be a Python float.
      score_threshold: The score threshold passed to NMS that allows NMS to
        quickly ignore irrelevant boxes. Must be a Python float.
      max_boxes_per_class: The maximum number of boxes per example to emit. If
        None, this value is set to num_boxes from the shape of
        predicted_bboxes.

    Returns:
      predicted_bboxes: Filtered bboxes after NMS of shape
        [batch_size, num_classes, max_boxes_per_class, 7].
      bbox_scores: A Tensor with the dtype of classification_scores holding
        the score for each box, of shape
        [batch_size, num_classes, max_boxes_per_class].
      valid_mask: A float32 Tensor with 1/0 values indicating the validity of
        each box. 1 indicates valid, and 0 invalid. Tensor of shape
        [batch_size, num_classes, max_boxes_per_class].

    Raises:
      ValueError: if nms_iou_threshold or score_threshold is not a float.
    """
    utils_3d = detection_3d_lib.Utils3D()
    predicted_bboxes = py_utils.HasShape(predicted_bboxes, [-1, -1, 7])
    batch_size, num_predicted_boxes, _ = py_utils.GetShape(predicted_bboxes)
    classification_scores = py_utils.HasShape(
        classification_scores, [batch_size, num_predicted_boxes, -1])
    _, _, num_classes = py_utils.GetShape(classification_scores)

    if not isinstance(nms_iou_threshold, float):
        raise ValueError('Single class NMS only supports a scalar '
                         '`nms_iou_threshold`.')
    if not isinstance(score_threshold, float):
        raise ValueError('Single class NMS only supports a scalar '
                         '`score_threshold`.')

    if max_boxes_per_class is None:
        max_boxes_per_class = num_predicted_boxes

    # TODO(jngiam): Change to be per-class bboxes, and hence, per-class NMS, and
    # per-class thresholding.
    # [batch, num_predicted_boxes]
    nms_scores = tf.reduce_max(classification_scores, axis=-1)

    # Compute the most likely label by computing the highest class score from
    # the output of the sigmoid.
    likely_labels = tf.argmax(classification_scores, axis=-1)

    # When background is the most likely class for the box, mask out the scores
    # of that box from NMS scoring so the background boxes don't dominate the
    # NMS. Cast to the scores' own dtype (rather than a fixed float32) so that
    # non-float32 score tensors (e.g. bfloat16) can be multiplied in place.
    nms_scores *= tf.cast(likely_labels > 0, nms_scores.dtype)

    # Compute NMS for every sample in the batch.
    nms_indices, valid_mask = utils_3d.BatchedNMSIndices(
        predicted_bboxes,
        nms_scores,
        nms_iou_threshold=nms_iou_threshold,
        score_threshold=score_threshold,
        max_num_boxes=max_boxes_per_class)

    # Reorder the box data and logits according to NMS scoring.
    predicted_bboxes = tf.array_ops.batch_gather(predicted_bboxes, nms_indices)
    classification_scores = tf.array_ops.batch_gather(classification_scores,
                                                      nms_indices)

    # Now reformat the output of NMS to match the format of the
    # MultiClassOrientedDecodeWithNMS, which outputs a per class NMS result.
    # This takes the leading shape of
    # [batch_size, num_classes, max_boxes_per_class] for all outputs, which
    # means since this NMS is not class specific we need to tile the outputs
    # num_classes times or reorder the data such that its [batch, num_classes].
    predicted_bboxes = tf.tile(predicted_bboxes[:, tf.newaxis, :, :],
                               [1, num_classes, 1, 1])
    classification_scores = tf.transpose(classification_scores, (0, 2, 1))
    classification_scores = py_utils.HasShape(
        classification_scores, [batch_size, num_classes, max_boxes_per_class])
    valid_mask = tf.tile(valid_mask[:, tf.newaxis, :], [1, num_classes, 1])
    return predicted_bboxes, classification_scores, valid_mask
Exemple #15
0
  def FProp(self, theta, *args):
    """Run multiple cells in different devices in a pipelining manner.

    In eval mode (p.is_eval) the before-layers and cells are simply applied
    one after another. Otherwise the inputs are run through the
    before-layers on the first device, split into p.num_micro_batches
    micro-batches along p.batch_dim, and pushed through the cells with
    recurrent.StackedRecurrent, one cell per device. The per-micro-batch
    outputs of the last cell are then merged back along p.batch_dim.

    NOTE(review): if p.num_micro_batches exceeds the mini-batch size, this
    method overwrites p.num_micro_batches on the params object itself.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      A list of output tensors. In the pipelined (non-eval) path, a single
      tensor is returned when the last cell has exactly one output; otherwise
      a tuple is returned, with None for outputs whose shape is None.
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if p.is_eval:
      outputs = _ToTuple(args)
      for (name, l) in self._before_layers:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      for (name, l) in self._cells:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      return outputs

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    input_tenors = _ToTuple(args)
    mini_batch_size = input_tenors[0].get_shape().as_list()[p.batch_dim]
    if p.state_dtype:
      state_dtype = p.state_dtype
    else:
      state_dtype = input_tenors[0].dtype
    # Never use more micro-batches than there are examples.
    if p.num_micro_batches > mini_batch_size:
      p.num_micro_batches = mini_batch_size
    micro_batch_size = mini_batch_size // p.num_micro_batches

    # Static per-micro-batch input shapes (batch dim replaced by the
    # micro-batch size); None inputs are carried through as None.
    input_shapes = ()
    for input_tensor in input_tenors:
      if input_tensor is not None:
        input_shape = input_tensor.get_shape().as_list()
        input_shape[p.batch_dim] = micro_batch_size
        input_shapes += (tf.TensorShape(input_shape),)
      else:
        input_shapes += (None,)

    # state_shapes[i] are the input shapes of cell i and state_shapes[i + 1]
    # its output shapes (see the indexing in CellFn below).
    state_shapes = self._CalculateOutputShapes(input_shapes)

    def GetCellFn(i):
      """Returns the cell fn that runs the i-th cell inside StackedRecurrent."""

      def CellFn(theta, state0, inputs):
        """A cell fn is executed inside of StackedRecurrent."""
        del state0
        # Restore static shapes on the recurrent inputs 's0', 's1', ...
        frop_inputs = []
        for input_idx in range(len(state_shapes[i])):
          name = 's{}'.format(input_idx)
          if state_shapes[i][input_idx] is not None:
            inputs[name].set_shape(state_shapes[i][input_idx])
            frop_inputs.append(inputs[name])
          else:
            frop_inputs.append(None)

        with CellFnFropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, frop_inputs))
          # Per-micro-batch step counter built from the global step below;
          # installed as the overriding global step while the cell runs.
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          outputs = cell.FProp(theta, *frop_inputs)

        # Forward the micro-batch counter and the non-None outputs as the
        # next state, using the same 's{idx}' naming as the inputs.
        state1 = py_utils.NestedMap()
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
        return state1, py_utils.NestedMap()

      return CellFn

    # Per-cell recurrent configuration: cell fn, theta, zero initial state
    # and placement device.
    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])
      init_state = py_utils.NestedMap()
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      for output_idx in range(len(state_shapes[cell_idx + 1])):
        name = 's{}'.format(output_idx)
        if state_shapes[cell_idx + 1][output_idx] is not None:
          init_state[name] = tf.zeros(
              state_shapes[cell_idx + 1][output_idx], dtype=state_dtype)
      init_states.append(init_state)
      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    # Default gradient handling and identity output transforms for each cell.
    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      previous = input_tenors
      for (name, l) in self._before_layers:
        previous = l.FProp(theta[name], *previous)
        previous = _ToTuple(previous)
      inputs = py_utils.NestedMap()
      gs_tensor = py_utils.GetGlobalStep()
      # One counter per micro-batch: global_step * num_micro_batches + t.
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ])

      # TODO(huangyp, dehao): apply dehao's trick to reshape the input tensor
      # to [p.num_micro_batches, -1, 128].
      # Split each input along p.batch_dim and stack the pieces on a new
      # leading micro-batch axis, which StackedRecurrent iterates over.
      # tf.split with an integer count needs the batch dim to divide evenly.
      for output_idx, output_tenor in enumerate(previous):
        name = 's{}'.format(output_idx)
        if output_tenor is not None:
          output_tenor = tf.stack(
              tf.split(output_tenor, p.num_micro_batches, axis=p.batch_dim))
          inputs[name] = output_tenor

    output, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):
      output_tensors = []
      for output_idx in range(len(state_shapes[-1])):
        state_shape = state_shapes[-1][output_idx]
        if state_shape is None:
          output_tensors.append(None)
          continue
        output_name = 's{}'.format(output_idx)
        output_tensor = output[output_name]
        # Move the leading micro-batch axis next to p.batch_dim so the
        # reshape below merges micro-batches back into the batch dimension.
        if p.batch_dim != 0:
          perm = list(range(1, p.batch_dim + 1)) + [0]
          perm += list(range(p.batch_dim + 1, len(state_shape) + 1))
          output_tensor = tf.transpose(output_tensor, perm=perm)
        # NOTE(review): this mutates the shape object stored in
        # state_shapes[-1] in place; copy it first if that ever matters.
        state_shape[p.batch_dim] *= p.num_micro_batches
        output_tensor = tf.reshape(output_tensor, state_shape)
        output_tensors.append(output_tensor)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      if len(output_tensors) == 1:
        return output_tensors[0]
      return tuple(output_tensors)
Exemple #16
0
    def FProp(self, theta, batch, state0=None):
        """Encodes source as represented by 'inputs' and 'paddings'.

        Pipeline: optional SpecAugment (training only) -> optional extra
        padding -> conv layers -> conv-lstm layers -> flattened features ->
        stacked rnn layers (with optional residual / highway skips,
        projections and a time-stacking layer).

        Args:
          theta: A NestedMap object containing weights' values of this
            layer and its children layers.
          batch: A NestedMap with fields:

            - src_inputs - The inputs tensor. It is expected to be of shape
              [batch, time, feature_dim, channels].
            - paddings - The paddings tensor. It is expected to be of shape
              [batch, time].
          state0: Recurrent input state. Not supported/ignored by this encoder.

        Returns:
          A NestedMap containing

          - 'encoded': a feature tensor of shape [time, batch, depth]
          - 'padding': a 0/1 tensor of shape [time, batch]
          - 'state': the updated recurrent state
          - '${layer_type}_${layer_index}': The per-layer encoder output. Each
            one is a NestedMap containing 'encoded' and 'padding' similar to
            regular final outputs, except that 'encoded' from conv or
            conv_lstm layers are of shape [time, batch, depth, channels].
        """
        p = self.params
        inputs, paddings = batch.src_inputs, batch.paddings
        outputs = py_utils.NestedMap()
        with tf.name_scope(p.name):
            # Adding specAugmentation.
            if p.use_specaugment and not self.do_eval:
                inputs, paddings = self.specaugment.FProp(
                    theta.specaugment, inputs, paddings)
            # Add a few extra padded timesteps at the end. This is for ensuring the
            # correctness of the conv-layers at the edges.
            if p.pad_steps > 0:
                # inplace_update() is not supported by TPU for now. Since we have done
                # padding on the input_generator, we may avoid this additional padding.
                assert not py_utils.use_tpu()
                # Same shape as the inputs/paddings but with p.pad_steps time steps.
                inputs_pad = tf.zeros(
                    inplace_ops.inplace_update(tf.shape(inputs), 1,
                                               p.pad_steps), inputs.dtype)
                paddings_pad = tf.ones(
                    inplace_ops.inplace_update(tf.shape(paddings), 1,
                                               p.pad_steps), paddings.dtype)
                inputs = tf.concat([inputs, inputs_pad], 1, name='inputs')
                paddings = tf.concat([paddings, paddings_pad], 1)

            plots = [
                summary_utils.PrepareSequenceForPlot(
                    tf.transpose(inputs, [0, 1, 3, 2]), paddings, 'inputs')
            ]

            # Conv layers operate batch-major.
            conv_out = inputs
            out_padding = paddings
            for i, conv_layer in enumerate(self.conv):
                conv_out, out_padding = conv_layer.FProp(
                    theta.conv[i], conv_out, out_padding)
                if p.extra_per_layer_outputs:
                    # Zero padded frames before exporting a time-major copy.
                    conv_out *= (1.0 -
                                 out_padding[:, :, tf.newaxis, tf.newaxis])
                    outputs['conv_%d' % i] = py_utils.NestedMap(
                        encoded=tf.transpose(conv_out,
                                             [1, 0, 2, 3]),  # to [t, b, d, c]
                        padding=tf.transpose(out_padding))
                plots.append(
                    summary_utils.PrepareSequenceForPlot(
                        tf.transpose(conv_out, [0, 1, 3, 2]), out_padding,
                        'conv_%d_out' % i))

            def TransposeFirstTwoDims(t):
                # Swaps dims 0 and 1 of `t` while keeping all trailing dims;
                # uses dynamic shapes, going through a rank-3 reshape.
                first_dim = tf.shape(t)[0]
                second_dim = tf.shape(t)[1]
                t_new = tf.transpose(
                    tf.reshape(t, [first_dim, second_dim, -1]), [1, 0, 2])
                t_shape_new = tf.concat([[second_dim], [first_dim],
                                         tf.shape(t)[2:]], 0)
                return tf.reshape(t_new, t_shape_new)

            # Now the conv-lstm part.
            # Each stage runs the rnn time-major over an extra size-1 axis
            # inserted at dim 2, then moves back to batch-major, drops that
            # axis and applies the cnn.
            conv_lstm_out = conv_out
            conv_lstm_out_padding = out_padding
            for i, (rnn, cnn) in enumerate(
                    zip(self.conv_lstm_rnn, self.conv_lstm_cnn)):
                conv_lstm_in = conv_lstm_out
                # Move time dimension to be the first.
                conv_lstm_in = TransposeFirstTwoDims(conv_lstm_in)
                conv_lstm_in = tf.expand_dims(conv_lstm_in, 2)
                conv_lstm_in_padding = tf.expand_dims(
                    tf.transpose(conv_lstm_out_padding), 2)
                lstm_out = rnn.FProp(theta.conv_lstm_rnn[i], conv_lstm_in,
                                     conv_lstm_in_padding)
                # Move time dimension to be the second.
                cnn_in = TransposeFirstTwoDims(lstm_out)
                cnn_in = tf.squeeze(cnn_in, 2)
                cnn_in_padding = conv_lstm_out_padding
                cnn_out, cnn_out_padding = cnn.FProp(theta.conv_lstm_cnn[i],
                                                     cnn_in, cnn_in_padding)
                conv_lstm_out, conv_lstm_out_padding = cnn_out, cnn_out_padding
                if p.extra_per_layer_outputs:
                    conv_lstm_out *= (
                        1.0 -
                        conv_lstm_out_padding[:, :, tf.newaxis, tf.newaxis])
                    outputs['conv_lstm_%d' % i] = py_utils.NestedMap(
                        encoded=tf.transpose(conv_lstm_out,
                                             [1, 0, 2, 3]),  # to [t, b, d, c]
                        padding=tf.transpose(conv_lstm_out_padding))
                plots.append(
                    summary_utils.PrepareSequenceForPlot(
                        conv_lstm_out, conv_lstm_out_padding,
                        'conv_lstm_%d_out' % i))

            # Need to do a reshape before starting the rnn layers.
            # Flatten the trailing two dims into one feature dim, then
            # right-pad features up to the first LSTM's expected input depth.
            conv_lstm_out = py_utils.HasRank(conv_lstm_out, 4)
            conv_lstm_out_shape = tf.shape(conv_lstm_out)
            new_shape = tf.concat([conv_lstm_out_shape[:2], [-1]], 0)
            conv_lstm_out = tf.reshape(conv_lstm_out, new_shape)
            if self._first_lstm_input_dim_pad:
                conv_lstm_out = tf.pad(
                    conv_lstm_out,
                    [[0, 0], [0, 0], [0, self._first_lstm_input_dim_pad]])

            conv_lstm_out = py_utils.HasShape(
                conv_lstm_out, [-1, -1, self._first_lstm_input_dim])

            # Transpose to move the time dimension to be the first.
            rnn_in = tf.transpose(conv_lstm_out, [1, 0, 2])
            rnn_padding = tf.expand_dims(tf.transpose(conv_lstm_out_padding),
                                         2)
            # rnn_in is of shape [time, batch, depth]
            # rnn_padding is of shape [time, batch, 1]

            # Now the rnn layers.
            num_skips = 0
            for i in range(p.num_lstm_layers):
                rnn_out = self.rnn[i].FProp(theta.rnn[i], rnn_in, rnn_padding)
                # Skip connections (when p.residual_start > 0) begin at layer
                # p.residual_start - 1: the input is captured at the start of
                # each p.residual_stride-layer window and combined with the
                # output at the window's last layer.
                residual_index = i - p.residual_start + 1
                if p.residual_start > 0 and residual_index >= 0:
                    if residual_index % p.residual_stride == 0:
                        residual_in = rnn_in
                    if residual_index % p.residual_stride == p.residual_stride - 1:
                        # Highway skip connection.
                        if p.highway_skip:
                            rnn_out = self.highway_skip[num_skips].FProp(
                                theta.highway_skip[num_skips], residual_in,
                                rnn_out)
                            num_skips += 1
                        else:
                            # Residual skip connection.
                            rnn_out += py_utils.HasShape(
                                residual_in, tf.shape(rnn_out))
                if p.project_lstm_output and (i < p.num_lstm_layers - 1):
                    # Projection layers.
                    rnn_out = self.proj[i].FProp(theta.proj[i], rnn_out,
                                                 rnn_padding)
                # Zero padded frames of the final layer's output.
                if i == p.num_lstm_layers - 1:
                    rnn_out *= (1.0 - rnn_padding)
                if p.extra_per_layer_outputs:
                    rnn_out *= (1.0 - rnn_padding)
                    outputs['rnn_%d' % i] = py_utils.NestedMap(
                        encoded=rnn_out, padding=tf.squeeze(rnn_padding, [2]))
                # Stacking layer connection.
                if p.layer_index_before_stacking == i:
                    # Stacking layer expects input tensor shape as [batch, time, feature].
                    # So transpose the tensors before and after the layer.
                    rnn_out, rnn_padding = self.stacking.FProp(
                        tf.transpose(rnn_out, [1, 0, 2]),
                        tf.transpose(rnn_padding, [1, 0, 2]))
                    rnn_out = tf.transpose(rnn_out, [1, 0, 2])
                    rnn_padding = tf.transpose(rnn_padding, [1, 0, 2])

                plots.append(
                    summary_utils.PrepareSequenceForPlot(
                        tf.transpose(rnn_out, [1, 0, 2]),
                        tf.transpose(rnn_padding, [1, 0, 2]),
                        'rnn_%d_out' % i))
                rnn_in = rnn_out
            final_out = rnn_in

            summary_utils.PlotSequenceFeatures(list(reversed(plots)),
                                               'encoder_example',
                                               xlabel='Time')

            outputs['encoded'] = final_out
            outputs['padding'] = tf.squeeze(rnn_padding, [2])
            outputs['state'] = py_utils.NestedMap()
            return outputs
Exemple #17
0
 def _process(record):
     """Decodes a pickled int32 record and returns it with a bucketing key.

     Returns a pair ([decoded, decoded transposed with perm [1, 0, 2]], key),
     where key is the size of the decoded tensor's first dimension.
     """
     # NOTE(review): pickle.loads can execute arbitrary code; only feed this
     # records from a trusted source.
     decoded = tf.py_func(pickle.loads, [record], tf.int32)
     length_key = tf.shape(decoded)[0]
     swapped = tf.transpose(decoded, [1, 0, 2])
     return [decoded, swapped], length_key
Exemple #18
0
 def Transpose(paddings):
     """Transposes each tensor in `paddings`.

     A single (non-list) argument is treated as a one-element list, so the
     return value is always a list of transposed tensors.
     """
     if not isinstance(paddings, list):
         paddings = [paddings]
     return list(map(tf.transpose, paddings))
    def Sample(self, decoder_theta, encoder_outputs, random_seed,
               init_state_callback, pre_step_callback, post_step_callback):
        """Samples target sequences, one target sequence per source sequence.

        (Please see beam_search_helper.py for description of decoder callbacks.)

        Sampling runs inside `recurrent.Recurrent`, driven by a dummy input with
        p.target_seq_len rows. At each step the next id is drawn from
        `log_probs / p.temperature` with `tf.random.stateless_categorical`,
        seeded by the pair (random_seed, timestep). When the callback result
        has `is_last_chunk` and p.target_eoc_id >= 0, a sampled end-of-chunk id
        on the last chunk is replaced with p.target_eos_id.

        Args:
          decoder_theta: A NestedMap object containing weights' values of the
            decoder layer and its children layers, to be passed to decoder
            callbacks.
          encoder_outputs: the outputs of the encoder, to be passed to callbacks.
          random_seed: a scalar int32 tensor representing the random seed.
          init_state_callback: decoder._InitBeamSearchStateCallback.
          pre_step_callback: decoder._PreBeamSearchStepCallback.
          post_step_callback: decoder._PostBeamSearchStepCallback.

        Returns:
          A NestedMap containing the following tensors

          - 'logits': [batch, max_target_length, vocab_size], representing the
            distribution from which target sequences are sampled.
          - 'ids': [batch, max_target_length] of int32, representing the target
            sequence ids, not including target_sos_id, but maybe ending with
            target_eos_id if end-of-sequence is reached before target_seq_len.
          - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
            a padded timestep.
        """
        p = self.params
        # Logits are divided by the temperature below, so it must be positive.
        assert p.temperature > 0
        if getattr(encoder_outputs, 'segment_id', 1) is None:
            # Remove None values, which are not supported by recurrent.
            del encoder_outputs['segment_id']
        # init_state_callback may modify 'encoder_outputs', e.g., by inserting
        # 'packed_src'.
        bs_result, bs_state = init_state_callback(decoder_theta,
                                                  encoder_outputs,
                                                  num_hyps_per_beam=1)
        # 'recurrent_theta' represents all cross-timestep information used by the
        # recurrent loop below, including layer theta and encoder outputs.
        recurrent_theta = py_utils.NestedMap(theta=decoder_theta,
                                             random_seed=random_seed,
                                             encoder_outputs=encoder_outputs)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=tf.fill([batch], tf.cast(p.target_sos_id, tf.int32)),
            bs_state=bs_state)
        # Dummy per-step inputs: only the leading p.target_seq_len dimension
        # matters, as it sets how many steps the recurrent loop runs.
        inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    recurrent_theta.theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=1)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs
                # Sample ids from logits. [batch].
                # The (seed, timestep) pair makes each step's draw distinct yet
                # reproducible for a given random_seed.
                state1.ids = tf.reshape(
                    tf.random.stateless_categorical(
                        state1.logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                # On the last chunk, an end-of-chunk id ends the sequence, so
                # rewrite it as end-of-sequence.
                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    recurrent_theta.theta, recurrent_theta.encoder_outputs,
                    state1.ids, bs_state1)
            return state1, py_utils.NestedMap()

        accumulated_states, _ = recurrent.Recurrent(
            recurrent_theta,
            recurrent_state0,
            inputs,
            Step,
            allow_implicit_capture=True)
        # Accumulated per-step states are stacked along a leading time axis;
        # transpose logits and ids to the batch-major layout in the docstring.
        result = py_utils.NestedMap(logits=tf.transpose(
            accumulated_states.logits, [1, 0, 2]),
                                    ids=tf.transpose(accumulated_states.ids))
        # _ComputePaddings is defined elsewhere; presumably it marks steps after
        # the first target_eos_id as padded — confirm.
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        # Restore static shape info lost through the dynamic recurrent loop.
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
    def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                        core_bs_states, other_states, num_hyps_per_beam,
                        pre_beam_search_step_callback,
                        post_beam_search_step_callback):
        """Extend beam search hyps for one step.

          | num_beams = Number of source sequences to be decoded.
          | num_hyps_per_beam = Number of hyps to keep per source sequence.
          | num_hyps = num_beams * num_hyps_per_beam
          | src_seq_len = Number of time steps in the source sequence.
          | src_batch = Number of examples in the source sequence.
          | tgt_seq_len = Maximum allowed time steps in the target sequence.
          | tgt_batch = num_hyps_per_beam * src_batch

        Args:
          theta: A `.NestedMap` object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
            the callbacks.
          cur_step: A scalar int tensor, the current time step, 0-based.
          step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
            current search step.
          core_bs_states: A tuple of core beam search states. This list is
            maintained by this helper class.
          other_states: A `.NestedMap` of other beam search states. This
            `.NestedMap` is managed and updated by the client. It is expected that
            each of its member tensors are of rank >= 1. t[i, ...] is the state of
            the i-th hyp at the beginning of this search step.
          num_hyps_per_beam: Num of hyps to keep per beam.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
            See class header comments for more details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
            See class header comments for more details.

        Returns:
          A tuple of following elements for the next beam search step,
          (next step, all_done, step_ids, core_bs_states, other_states)
        """
        p = self.params

        # The client computes this step's log-probs / attention for every hyp.
        bs_results, other_states = pre_beam_search_step_callback(
            theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

        (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
         in_done_hyps, in_atten_probs) = core_bs_states

        # The custom op extends/prunes hyps; the eoc flag is only passed when
        # the model uses end-of-chunk ids.
        (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
         out_prev_hyps, out_done_hyps, out_atten_probs,
         all_done) = ops.beam_search_step(
             tf.cast(bs_results.log_probs, dtype=p.dtype),
             tf.cast(bs_results.atten_probs, dtype=p.dtype),
             best_scores,
             cumulative_scores,
             in_scores,
             in_hyps,
             in_prev_hyps,
             in_done_hyps,
             in_atten_probs,
             bs_results.is_last_chunk if self._model_uses_eoc_id else [],
             cur_step,
             eoc_id=p.target_eoc_id,
             eos_id=p.target_eos_id,
             beam_size=p.beam_size,
             num_hyps_per_beam=num_hyps_per_beam,
             valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
             merge_paths=p.merge_paths,
             allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
             ensure_full_beam=p.ensure_full_beam,
             force_eos_in_last_step=p.force_eos_in_last_step,
             local_eos_threshold=p.local_eos_threshold)

        # The ids chosen at cur_step are the inputs for the next step, kept in
        # the same [num_hyps, 1] shape as step_ids.
        new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
        new_step_ids.set_shape(step_ids.get_shape())

        # For each surviving hyp, the index of the hyp it was extended from.
        old_hyp_ids = tf.reshape(
            tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

        if p.batch_major_compute:
            # Transformed the indices into the key/value cache for fast decoding
            # (prefix_states in other_states) due to the num_hyps dimension of
            # cache is computed as num_beams by num_hyps_per_beam, which is different
            # from the old_hyp_ids assumption (num_hyps_per_beam by num_beams).
            # Both transpose and recomputation are required to correct the indices.
            num_beams = tf.shape(best_scores)[0]
            old_hyp_ids_in_cache_order = tf.reshape(
                tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])),
                [-1])
            old_hyp_ids_in_cache_order = (
                (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
                old_hyp_ids_in_cache_order // num_beams)

        new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                         out_hyps, out_prev_hyps, out_done_hyps,
                         out_atten_probs)

        def ReOrderHyps(x_in):
            """Reorders x_in based on prev hyp ids."""
            # Only non-scalar tf.Tensors are reordered; anything else is
            # returned unchanged.
            if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims
                    and x_in.shape.ndims > 0):
                # Rank > 2 states (when not batch-major) carry the hyp dimension
                # on axis 1 — presumably [time, num_hyps, ...] caches; confirm.
                if x_in.shape.ndims > 2 and not p.batch_major_state:
                    # Use corrected indices only here for batch major compute as key/value
                    # caches are the states being affected.
                    correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                                           if p.batch_major_compute else
                                           old_hyp_ids)
                    x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
                else:
                    x_out = tf.gather(x_in, old_hyp_ids)
                x_out.set_shape(x_in.get_shape())
                return x_out
            else:
                return x_in

        new_other_states = other_states.Transform(ReOrderHyps)

        final_other_states = post_beam_search_step_callback(
            theta, encoder_outputs, new_step_ids, new_other_states)

        return (cur_step + 1, all_done, new_step_ids, new_bs_states,
                final_other_states)
    def ComputePredictions(self, theta, input_batch):
        """Encodes spellings, runs the decoder and beam-search decodes.

        The spelling ids are encoded. When p.use_neighbors is set with
        p.neigh_att_type == "CONCAT", they are concatenated with the flattened
        neighbour spellings/pronunciations and per-token neighbour task ids.
        With any other attention type, the neighbour encodings are appended
        to the encoder output as auxiliary inputs the decoder can attend to.
        The decoder then computes predictions with the SOS-prefixed
        pronunciation as targets. A beam search is run, and its top beam is
        recorded in self.per_example_tensors["hyp"].

        Args:
          theta: A `.NestedMap` of weights for this layer and its children.
          input_batch: A `.NestedMap` with at least `spelling`, `pronunciation`
            and `cognate_id`, plus `neighbor_spellings` and
            `neighbor_pronunciations` when p.use_neighbors is set.

        Returns:
          The decoder's predictions `.NestedMap`, with `batch` set to
          `input_batch`; it is also stored in self.prediction_values.
        """
        p = self.params
        batch_size = p.input.batch_size
        self._shape_batch(input_batch)

        # Prepend SOS token, this is not done by the Transformer layer for you
        # since this is usually done by the input pipeline in Babelfish.
        pronunciation = self._AddStartToken(input_batch.pronunciation)

        if p.use_neighbors:
            spellings = input_batch.neighbor_spellings
            pronunciations = input_batch.neighbor_pronunciations

        inp = {
            "ids": input_batch.spelling,
        }

        if (p.use_neighbors and p.also_shuffle_neighbors
                and (p.neigh_att_type == "CONCAT" or p.use_neigh_id_emb)):
            # If we use neighbor IDs, shuffle the neighbours to stop the model
            # overfitting to the ordering of the neighbours.
            # Concat then shuffle and split so that the spelling and pronunciation
            # are shuffled the same way and the IDs are aligned.
            neighbor_info = tf.concat([spellings, pronunciations], axis=-1)
            # Transpose the max_neighbors dimension to the front and shuffle.
            neighbor_info = tf.transpose(
                tf.random.shuffle(tf.transpose(neighbor_info, (1, 2, 0))),
                (2, 0, 1))
            spellings, pronunciations = (
                neighbor_info[:, :, :p.max_spelling_len],
                neighbor_info[:, :, p.max_spelling_len:])

        if p.use_neighbors and p.neigh_att_type == "CONCAT":
            # Interleave and flatten the neighbours info
            # ->(batch_size, max_neighbors, max_spelling_len + max_pronunciation_len)
            neigh_info = tf.concat([spellings, pronunciations], axis=2)
            # ->(batch_size, max_neighbors*(max_spelling_len + max_pronunciation_len))
            neigh_info = tf.reshape(neigh_info, (batch_size, -1))

            inp["ids"] = tf.concat([inp["ids"], neigh_info], axis=1)

            # If we are just concatenating everything then the main encoder needs
            # neighbors IDs.
            # Each neighbour index is repeated once per token of that neighbour,
            # matching the flattened neigh_info layout above.
            neigh_ids = tf.range(p.max_neighbors)[:, tf.newaxis]
            neigh_ids = tf.tile(
                neigh_ids,
                (batch_size, p.max_spelling_len + p.max_pronunciation_len))
            neigh_ids = tf.reshape(neigh_ids, (batch_size, -1))
            # Add the ids for the main input
            # (the main spelling uses task id p.max_neighbors, one past the
            # last neighbour index).
            main_ids = tf.tile([[p.max_neighbors]],
                               (batch_size, p.max_spelling_len))
            inp["task_ids"] = tf.concat([main_ids, neigh_ids], axis=1)

        inp["paddings"] = self._GetPaddings(inp["ids"], dtype=tf.int32)
        enc_out = self.encoder.FProp(theta.encoder, py_utils.NestedMap(inp))

        # Auxiliary inputs that the decoder can attend to, currently can be
        # neighbour summaries.
        aux_inputs = []
        aux_paddings = []

        if p.use_neighbors and p.neigh_att_type != "CONCAT":
            neigh_enc, padding = self._GetAxiliaryNeighInputs(
                spellings, pronunciations, enc_out, theta, batch_size)

            aux_inputs.extend(neigh_enc)
            aux_paddings.extend(padding)

        if aux_inputs:
            # Auxiliary inputs are stacked and appended along axis 0 of the
            # encoder output (presumably the time axis of [time, batch, dim]).
            aux_inputs = tf.concat(aux_inputs, axis=0)
            aux_paddings = tf.concat(aux_paddings, axis=0)

            if p.aux_dropout_prob and not self.do_eval:
                # noise_shape's trailing 1 drops whole feature vectors rather
                # than individual features.
                aux_inputs = tf.nn.dropout(
                    aux_inputs,
                    p.aux_dropout_prob,
                    noise_shape=(aux_inputs.get_shape().as_list()[0],
                                 batch_size, 1))

            enc_out.encoded = tf.concat([enc_out.encoded, aux_inputs], axis=0)
            enc_out.padding = tf.concat([enc_out.padding, aux_paddings],
                                        axis=0)

        enc_out.embedded_inputs = None  # to verify this is not used
        predictions = self.decoder.ComputePredictions(
            theta.decoder, enc_out,
            py_utils.NestedMap({
                "ids":
                pronunciation,
                "paddings":
                self._GetPaddings(pronunciation),
                "weights":
                tf.ones_like(input_batch.pronunciation, dtype=tf.float32),
            }))

        beam_out = self.decoder.BeamSearchDecode(enc_out, p.beam_size)
        top_ids = tf.reshape(beam_out.topk_ids,
                             [batch_size, -1, p.max_pronunciation_len])
        # Just take the top beam decodings
        top_ids = top_ids[:, 0, :]

        if p.is_inference:
            self.BuildInferenceInfo(top_ids, input_batch.pronunciation,
                                    enc_out)
            self.per_example_tensors["beam_scores"] = beam_out.topk_scores

        self.per_example_tensors["hyp"] = top_ids
        self.per_example_tensors["cognate_id"] = input_batch.cognate_id
        self.per_example_tensors["inp"] = input_batch.spelling
        self.per_example_tensors["ref"] = input_batch.pronunciation
        if p.use_neighbors:  # Note that cannot return None!
            self.per_example_tensors[
                "neighbor_spellings"] = input_batch.neighbor_spellings
            self.per_example_tensors[
                "neighbor_pronunciations"] = input_batch.neighbor_pronunciations
        self.prediction_values = predictions
        predictions.batch = input_batch

        return predictions
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
    """Merges beam search hyps from multiple decoders.

    The hyps of all outputs are concatenated per source example. Hyps with
    length 0 are demoted to a score of -1e6, and the `max_hyps_per_beam`
    highest-scoring hyps of each example are kept.

    Args:
      max_hyps_per_beam: the number of top hyps in the merged results. Must be
        less than or equal to total number of input hyps.
      beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
        the same source_batch and max sequence length.

    Returns:
      A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
      beam.

    Raises:
      ValueError: if a field is set in only some of `beam_search_outputs`, or
        if an output carries a field this function does not know how to merge.
    """
    source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
    value_dict = {}
    for output in beam_search_outputs:
        # hyps_per_beam carries a control dependency asserting that this output
        # has the same source batch as the first one.
        hyps_per_beam = py_utils.with_dependencies([
            py_utils.assert_equal(source_batch,
                                  tf.shape(output.topk_hyps)[0]),
        ], tf.shape(output.topk_hyps)[1])
        for k, v in output._asdict().items():
            if v is None:
                continue
            if k == 'done_hyps':
                # done_hyps keeps hyps on its last axis (it is transposed back
                # below), so bring hyps to the front like the other fields.
                v = tf.transpose(v)
            value_dict.setdefault(k, []).append(
                tf.reshape(v, [source_batch, hyps_per_beam, -1]))

    # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
    concatenated = {}
    for k, values in value_dict.items():
        if len(values) != len(beam_search_outputs):
            raise ValueError('Incomplete values for %s: %s' %
                             (k, beam_search_outputs))
        concatenated[k] = tf.concat(values, axis=1)

    scores = concatenated['topk_scores']
    # Empty hyps must never win. Build the fill value in the scores' own dtype
    # so tf.where does not fail on a float32/other-dtype mismatch.
    scores = tf.where(tf.equal(concatenated['topk_lens'], 0),
                      tf.fill(tf.shape(scores), tf.cast(-1e6, scores.dtype)),
                      scores)
    scores = tf.squeeze(scores, -1)

    # Select top max_hyps_per_beam indices per beam.
    _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
    batch_ids = tf.tile(tf.expand_dims(tf.range(source_batch), -1),
                        [1, max_hyps_per_beam])
    # [source_batch, max_hyps_per_beam, 2]
    gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

    # Gather the merged top hyps according to 'gather_indices'.
    top = beam_search_outputs[0]._asdict()
    total_hyps = source_batch * max_hyps_per_beam
    for k, v in concatenated.items():
        v = tf.gather_nd(v, gather_indices)
        # Restore each field's original layout.
        if k == 'done_hyps':
            v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
        elif k == 'topk_hyps':
            v = tf.reshape(v, [source_batch, max_hyps_per_beam])
        elif k == 'topk_ids':
            v = tf.reshape(v, [total_hyps, -1])
        elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
            v = tf.reshape(v, [total_hyps])
        else:
            raise ValueError('Unexpected field: %s' % k)
        top[k] = v
    return BeamSearchDecodeOutput(**top)
Exemple #23
0
 def _BroadcastAcrossPoints(z):
   """Tiles rank-2 `z` num_points times along axis 1, then transposes it.

   `num_points` is a free variable resolved from an enclosing scope.
   """
   tiled = tf.tile(z, [1, num_points])
   return tf.transpose(tiled)
Exemple #24
0
    def _ConstructWarpMatrix(self, batch_size, matrix_size, origin,
                             destination, choose_range, dtype):
        """Returns warp matrices according to origin, destination and choose_range.

        This function constructs a batch of warp matrices which maps the batch
        of origin points to the batch of destination points with fixed boundary
        coordinates at 0 and choose_range.

        The warping function, defined by the origin anchor point `origin`,
        the destination of the origin anchor point `destination` and the
        length of the domain in the warping axis `choose_range` is a piecewise
        linear map that fixes the points 0 and `choose_range` and maps
        `origin` to `destination`.

        For the warping matrix to be non-singular, destination must lie in the
        range 1<= destination <= choose_range - 1, so a destination
        out of this range is adjusted to be in this range before the warping
        matrix is constructed.

        The warping map can be explicitly written by first defining the slopes:
          1) slope_0 = origin / destination.
          2) slope_1 = (choose_range - origin) / (choose_range - destination).
          3) slope_2 = 1.0.

        Then the origin point orig_i of the mapped coordinate i is given by:
          1) i < destination: orig_i = slope_0 * i.
          2) destination <= i < choose_range:
             orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
          3) i >= choose_range: orig_i = i.

        Denoting n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
          1) j = n_i: 1 - n_i + orig_i.
          2) j = n_i - 1: n_i - orig_i.
          3) Otherwise: 0.

        Applying the warp matrix to an array of pixels, i.e.,
        warped_pixel[i] = sum_j warp[i][j] * pixel[j], one would get
        warped_pixel[i] = (n_i-orig_i) pixel[n_i-1] + (1-n_i+orig_i) pixel[n_i].

        Args:
          batch_size: Batch size. Integer number.
          matrix_size: Dimension of the vector space the warp matrix is applied to.
            Integer number.
          origin: Origin anchor point for warping. Tensor of shape (batch_size,) and
            data type dtype.
          destination: Destination of the origin anchor point upon warping. Tensor
            of shape (batch_size,) and data type dtype.
          choose_range: Range within which the warp reference points must lie.
            Tensor of shape (batch_size,) data type dtype.
          dtype: Data type of origin, destination, choose_range and the output warp
            matrix.

        Returns:
          warp_matrix: An array of fixed size warp matrices with shape
          (batch_size, matrix_size, matrix_size). It is cast to p.fprop_dtype
          when that is set and differs from `dtype`.
        """
        p = self.params

        # Entries of destination must be in the range
        # 1 <= destination <= choose_range - 1
        # for warp matrix to have non-singular values.
        destination = tf.minimum(tf.maximum(destination, 1.0),
                                 choose_range - 1.0)

        # Construct piece-wise linear function fixing boundary points
        # specified by zero, choose_range and matrix size and maps
        # the origin anchor point to the destination.
        # Broadcasting (batch_size,) to (matrix_size, batch_size) and then
        # transposing yields (batch_size, matrix_size): each example's scalar
        # repeated across the coordinate axis.
        destination_bc = tf.broadcast_to(destination,
                                         (matrix_size, batch_size))
        destination_bc = tf.transpose(destination_bc)
        choose_range_bc = tf.broadcast_to(choose_range,
                                          (matrix_size, batch_size))
        choose_range_bc = tf.transpose(choose_range_bc)

        # Slopes of piece-wise linear function.
        slope_0 = origin / destination
        slope_1 = (choose_range - origin) / (choose_range - destination)
        slope_2 = 1.0

        # x is a batch of origin matrices.
        # The origin matrix is the matrix such that
        # origin[i][j] = Origin coordinate of coordinate i for the warp map.
        # Denoting the destination of the origin anchor point in the
        # warp map as "dest," the origin coordinate of point i is given by:
        # 1) i < dest: slope_0 * i.
        # 2) dest <= i < choose_range: slope_1 * i - (slope_1 - slope_0) * dest.
        # 3) i >= choose_range: i.
        # The relu terms switch on the slope changes at dest and choose_range.
        # EinsumBBmBm is defined elsewhere; presumably it scales row b of a
        # (batch, matrix) tensor by the b-th scalar — confirm.
        x = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size))
        x = (self.EinsumBBmBm(slope_0, x) + self.EinsumBBmBm(
            slope_1 - slope_0, tf.nn.relu(x - destination_bc)) +
             self.EinsumBBmBm(slope_2 - slope_1,
                              tf.nn.relu(x - choose_range_bc)))
        # Repeat orig_i along a new last axis: result[b, i, j] = orig_i.
        x = tf.broadcast_to(x, (matrix_size, batch_size, matrix_size))
        x = tf.transpose(x, perm=[1, 2, 0])

        # y is a batch of coordinate matrices.
        # A coordinate matrix is a matrix such that
        # coordinate[i][j] = j.
        y = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size, matrix_size))
        # Warp matrix is obtained by applying hat function element-wise to (x-y).
        # Denoting the origin point of i under the warp map as orig_i,
        # and n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
        # 1) j = n_i: 1 - n_i + orig_i.
        # 2) j = n_i - 1: n_i - orig_i.
        # 3) Otherwise: 0.
        # Applying the warp matrix to pixels, i.e.,
        # warped_pixel[i] = sum_j warp[i][j] * original_pixel[j], one would get
        # warped_pixel[i] = (n_i - orig_i) * original_pixel[n_i-1]
        #                   + (1 - n_i + orig_i) * original_pixel[n_i].
        # _hat is defined elsewhere; assumed to be max(0, 1 - |t|) — confirm.
        warp_matrix = x - y
        warp_matrix = _hat(warp_matrix)
        if p.fprop_dtype is not None and p.fprop_dtype != dtype:
            warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)

        return warp_matrix
Exemple #25
0
  def testDecoderFPropWithAdapters(self):
    """Create decoder with adapters, and verify that FProp runs.

    Builds a 2-layer decoder with a 16-task adapter layer selected by the
    `domain_ids` field of the encoder outputs. It runs FProp on 4 target
    sequences and checks the loss metrics and the per-sequence loss shape.
    """
    with self.session(use_gpu=False):
      tf.random.set_seed(8372749040)

      params = _DecoderParams(
          num_rnn_layers=2,
          vn_config=py_utils.VariationalNoiseParams(
              None, True, False, seed=12345))
      params.rnn_cell_dim = 3
      params.adapter_layer_tpl.Set(
          bottleneck_dim=4,
          num_tasks=16,
          projection_params_init=py_utils.WeightInit.Gaussian(0.01))
      params.adapter_task_id_field = 'domain_ids'

      dec = params.Instantiate()
      src_seq_len = 5
      src_enc = tf.random.normal([src_seq_len, 2, 8],
                                 seed=982774838,
                                 dtype=py_utils.FPropDtype(params))
      src_enc_padding = tf.constant(
          [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
          dtype=py_utils.FPropDtype(params))
      # Use a seeded local RNG so the adapter task ids, and therefore the test
      # inputs, are reproducible across runs (the global numpy RNG is unseeded).
      domain_ids = tf.constant(
          np.random.RandomState(12345).randint(low=0, high=16, size=[2]))
      encoder_outputs = py_utils.NestedMap(
          encoded=src_enc, padding=src_enc_padding, domain_ids=domain_ids)
      # shape=[4, 5]
      target_ids = tf.transpose(
          tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                       [5, 6, 7, 8], [10, 5, 2, 5]],
                      dtype=tf.int32))
      # shape=[4, 5]
      target_labels = tf.transpose(
          tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                       [5, 7, 8, 10], [10, 5, 2, 4]],
                      dtype=tf.int32))
      # shape=[4, 5]
      target_paddings = tf.transpose(
          tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                       [1, 1, 1, 0]],
                      dtype=py_utils.FPropDtype(params)))
      target_transcripts = tf.constant(['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
      target_weights = 1.0 - target_paddings
      # ids/labels/weights/paddings are all in [batch, time] shape.
      targets = py_utils.NestedMap({
          'ids': target_ids,
          'labels': target_labels,
          'weights': target_weights,
          'paddings': target_paddings,
          'transcripts': target_transcripts,
      })
      decoder_outputs = dec.FPropDefaultTheta(encoder_outputs, targets)
      metrics = decoder_outputs.metrics
      per_sequence_loss = decoder_outputs.per_sequence['loss']

      self.assertIn('fraction_of_correct_next_step_preds', metrics)
      self.evaluate(tf.global_variables_initializer())
      metrics_val, per_sequence_loss_val = self.evaluate(
          [metrics, per_sequence_loss])
      tf.logging.info('metrics=%s, per_sequence_loss=%s', metrics_val,
                      per_sequence_loss_val)

      self.assertEqual(metrics_val['loss'], metrics_val['log_pplx'])
      # Target batch size is 4. Therefore, we should expect 4 here.
      self.assertEqual(per_sequence_loss_val.shape, (4,))
Exemple #26
0
    def _testDecoderFPropGradientCheckerHelper(self, func_inline=False):
        """Checks the decoder loss against a golden value and verifies gradients.

        Builds a float64 decoder and checks that the loss matches the golden
        value on two consecutive evaluations. It then compares every non-None
        symbolic gradient with a numerically estimated gradient for the same
        variable.

        Args:
          func_inline: Value used for `do_function_inlining` in the session's
            graph optimizer options.
        """
        config = tf.ConfigProto(graph_options=tf.GraphOptions(
            optimizer_options=tf.OptimizerOptions(
                do_function_inlining=func_inline)))
        with self.session(use_gpu=False, config=config) as sess:
            tf.set_random_seed(8372749040)
            np.random.seed(274854)
            vn_config = py_utils.VariationalNoiseParams(None, False, False)
            p = self._DecoderParams(vn_config)
            p.dtype = tf.float64

            dec = p.Instantiate()
            src_seq_len = 5
            src_enc = tf.constant(np.random.uniform(size=(src_seq_len, 2, 8)),
                                  tf.float64)
            src_enc_padding = tf.constant(
                [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                dtype=tf.float64)
            encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                                 padding=src_enc_padding)
            target_ids = tf.transpose(
                tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                             [5, 6, 7, 8], [10, 5, 2, 5]],
                            dtype=tf.int32))
            target_labels = tf.transpose(
                tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                             [5, 7, 8, 10], [10, 5, 2, 4]],
                            dtype=tf.int32))
            target_paddings = tf.transpose(
                tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0],
                             [0, 1, 0, 0], [1, 1, 1, 1]],
                            dtype=tf.float64))
            target_transcripts = tf.constant(
                ['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
            target_weights = 1.0 - target_paddings

            targets = py_utils.NestedMap({
                'ids': target_ids,
                'labels': target_labels,
                'weights': target_weights,
                'paddings': target_paddings,
                'transcripts': target_transcripts,
            })
            metrics = dec.FPropDefaultTheta(encoder_outputs, targets).metrics
            loss = metrics['loss'][0]
            all_vars = tf.trainable_variables()
            grads = tf.gradients(loss, all_vars)

            def DenseGrad(var, grad):
                """Returns `grad` as a dense tensor, or None (e.g. no gradient)."""
                if isinstance(grad, tf.Tensor):
                    return grad
                elif isinstance(grad, tf.IndexedSlices):
                    # Scatter the sparse rows back into a dense tensor with
                    # var's leading dimension.
                    return tf.unsorted_segment_sum(grad.values, grad.indices,
                                                   tf.shape(var)[0])
                return None

            dense_grads = [DenseGrad(x, y) for (x, y) in zip(all_vars, grads)]
            # Keep each variable paired with its own gradient. Variables without
            # a gradient are dropped from BOTH sides; filtering only the
            # symbolic side would misalign the pairwise comparison below.
            checked = [(var, grad) for var, grad in zip(all_vars, dense_grads)
                       if grad is not None]

            tf.global_variables_initializer().run()

            test_utils.CompareToGoldenSingleFloat(self, 3.458078, loss.eval())
            # Second run to make sure the function is deterministic.
            test_utils.CompareToGoldenSingleFloat(self, 3.458078, loss.eval())

            symbolic_grads = [grad.eval() for _, grad in checked]
            numerical_grads = [
                test_utils.ComputeNumericGradient(sess, loss, var)
                for var, _ in checked
            ]

            for x, y in zip(symbolic_grads, numerical_grads):
                self.assertAllClose(x, y)
Exemple #27
0
  def FProp(self, theta, input_batch):
    """Embeds source ids and transforms them with the TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` object containing:

        - ids: The inputs tensor of shape [batch, time].
        - paddings: The ids' paddings of shape [batch, time].
        - segment_ids: Only read when p.packed_input is set; segment id of each
          token, of shape [batch, time].
        - segment_pos: Only read when p.packed_input is set; position of each
          token within its segment (presumably [batch, time] — TODO confirm).

    Returns:
      A `.NestedMap` object containing:

        - encoded: The encoded features of shape [time, batch, dim] or
          [batch, time, dim], depending on p.output_data_format.
        - padding: The encoded features' padding of shape [time, batch] or
          [batch, time].
        - seq_lengths: int32 tensor with the number of non-padded positions of
          each sequence (used by beam_search_helper).
        - segment_id: The segmentation of packed inputs of shape [time, batch]
          or [batch, time] if p.packed_input is set, or None otherwise.
        - embedded_inputs: The embedded input tokens without positional
          encodings, of shape [time, batch, dim] or [batch, time, dim].
    """
    p = self.params
    with tf.name_scope(p.name):
      # [batch, time]
      input_ids = input_batch.ids
      # [batch, time]
      paddings = input_batch.paddings
      # [batch, time]
      segment_ids = input_batch.segment_ids if p.packed_input else None

      batch, time = py_utils.GetShape(input_ids, 2)

      # Token embeddings, [batch, time, dim]. With p.shared_emb the softmax
      # layer's table doubles as the embedding table.
      if p.shared_emb:
        input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
      else:
        input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
      orig_input_embs = input_embs

      if p.packed_input:
        # Positions restart inside every packed segment, so embed the explicit
        # per-token positions. FPropWithPosition already returns a batched
        # tensor; adding a leading axis here would turn the sum below into a
        # rank-4 tensor.
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, input_batch.segment_pos)
      else:
        # [1, time, dim], broadcast over the batch.
        position_embs = tf.expand_dims(
            self.position_emb.FProp(theta.position_emb, time), 0)

      # [batch, time, dim]
      input_embs += position_embs

      if p.input_dropout_tpl.fprop_dtype:
        input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
        paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
      # Explicitly set the input shape of Transformer layers, to avoid
      # unknown shape error occurred to tf.einsum on nonTPU devices.
      # [batch, time, dim]
      transformer_input = tf.reshape(input_embs, [batch, time, p.model_dim])

      # Compute the self-attention segment mask once. Both branches build it in
      # the transformer input's dtype so the stack receives a consistently
      # typed mask whether or not fprop_dtype is set.
      if p.packed_input:
        segment_mask = batch_major_attention.SegmentMask(
            segment_ids, segment_ids, dtype=transformer_input.dtype)
      else:
        segment_mask = tf.zeros([batch, 1, time, time],
                                dtype=transformer_input.dtype)

      encoded, padding = self.transformer_stack.FProp(theta.transformer_stack,
                                                      transformer_input,
                                                      paddings, segment_mask)

      if p.final_layer_norm:
        encoded = self.final_ln.FProp(theta.final_ln, encoded)

      # Count in float32: a low-precision fprop dtype such as bfloat16 cannot
      # represent every integer length exactly, which would corrupt lengths.
      seq_lengths = tf.cast(
          tf.reduce_sum(1. - tf.cast(padding, tf.float32), axis=1), tf.int32)

      if p.output_data_format == 'TBC':
        encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
        padding = tf.transpose(padding)  # [time, batch]
        segment_ids = tf.transpose(segment_ids) if p.packed_input else None
        orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

      return py_utils.NestedMap(
          encoded=encoded,
          padding=padding,
          seq_lengths=seq_lengths,  # used by beam_search_helper.
          segment_id=segment_ids,
          embedded_inputs=orig_input_embs)
Exemple #28
0
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
    """Samples num_sampled_points from points using farthest point sampling.

    Algorithm:
    1. Start by selecting a random point and adding to a selected set.
    2. For all remaining points, find the furthest point from those selected.
    3. Add furthest point to selected.
    4. Repeat 2-3 until num_sampled_points are selected.

    More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

    The output of this function can be used with tf.batch_gather to extract
    the desired points, for example: tf.batch_gather(points, sampled_idx)

    Args:
      points: floating point tf.Tensor of shape [N, P1, dims].
      padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point
        is real, and 1 otherwise.
      num_sampled_points: integer number of points to sample (P2).
      precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
        squared distances between each pair of points. If None, distances are
        computed on the fly from `points`.
      num_seeded_points: If num_seeded_points > 0, then the first
        num_seeded_points in points are considered to be seeded in the FPS
        sampling. Note that we assume that these points are *not* padded, and
        do not check padding when seeding them.
      random_seed: optional integer random seed to use with all the random ops.

    Returns:
      A tuple of tf.Tensors (sampled_idx, closest_idx) of types
      (tf.int32, tf.int32).

      sampled_idx is of shape [N, num_sampled_points] representing the indices
      selected using the sampler. Values are in the range [0, P1).

      closest_idx is of shape [N, P1]. For each input point it holds the
      position, within sampled_idx, of the closest sampled point. closest_idx
      is used in PCNN as part of the pooling operation: each point is assigned
      to the closest sampled point and a max is taken over them. Values are in
      the range [0, P2).
    """
    points = py_utils.HasRank(points, 3)
    batch_size, num_points, dims = py_utils.GetShape(points, 3)

    points = py_utils.with_dependencies(
        [py_utils.assert_greater_equal(num_points, num_sampled_points)],
        points)

    # Add a tiny bit of noise to the distance matrix or to the points so that
    # duplicated points (such as identical padded points) get distinct
    # distances instead of staying tied.
    if precomputed_squared_distance is not None:
        precomputed_squared_distance = py_utils.HasShape(
            precomputed_squared_distance, [batch_size, num_points, num_points])
        distance_dtype = precomputed_squared_distance.dtype
        # The loop reads row `s` of this matrix (distances from the selected
        # point s to every candidate j), so the noise must vary along the last
        # (candidate) axis; noise that is constant along a row leaves duplicate
        # candidates tied. Noise on the diagonal is zeroed so a point's
        # distance to itself is left unchanged.
        noise = tf.random.uniform((batch_size, 1, num_points),
                                  minval=1e-6,
                                  maxval=1e-5,
                                  dtype=distance_dtype,
                                  seed=random_seed)
        noise *= 1.0 - tf.eye(num_points, dtype=distance_dtype)
        precomputed_squared_distance += noise
    else:
        distance_dtype = points.dtype
        points += tf.random.uniform((batch_size, num_points, dims),
                                    minval=1e-6,
                                    maxval=1e-5,
                                    dtype=distance_dtype,
                                    seed=random_seed)

    # TensorArray to store the sampled indices in the loop.
    sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

    # Initialize distance_to_selected to inf for all points.
    distance_to_selected = float('inf') * tf.ones((batch_size, num_points),
                                                  dtype=distance_dtype)

    # For tracking the index to the closest selected point.
    closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

    # Current loop index counter.
    curr_idx = tf.constant(0, dtype=tf.int32)

    # A point is real iff its padding is 0; this same predicate drives the
    # masks below, so count valid points with it too. [N]
    is_real = tf.equal(padding, 0.0)
    num_valid_points = tf.reduce_sum(tf.cast(is_real, tf.int32), axis=1)

    def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
        """Loop body for farthest point sampler."""

        def _GetRandomRealPoint():
            """Select the first point.

            For the first point, we want any random real (non padded) point, so
            we create a random value per point, and then set all padded ones to
            some large value (more than the maxval). We then take the min per
            batch element to get the first points.

            Returns:
              Tensor containing the index of a random point selected for each
              example in the batch.
            """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(is_real, random_values,
                                     tf.cast(padding, tf.float32) * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)

        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

            We also bias the sampling towards real points by setting the
            distance to padded points negative until we are out of real points.

            Returns:
              Tensor containing the index of the next farthest point selected
              for each example in the batch.
            """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                is_real, distance_to_selected,
                -1.0 * tf.ones_like(distance_to_selected))
            # But only do this when we still have valid points left. The
            # per-example condition is tiled to [N, P1] so tf.where sees
            # matching shapes under both the v1 and v2 `where` semantics.
            has_valid_left = tf.tile(
                tf.expand_dims(tf.less(curr_idx, num_valid_points), 1),
                [1, num_points])
            padding_masked_distance_to_selected = tf.where(
                has_valid_left, padding_masked_distance_to_selected,
                distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)

        def _GetSeededPoint():
            """Select a seeded point.

            Seeded points are assumed to be at the beginning of the original
            points.

            Returns:
              Tensor containing the index of the next seeded point to select
              for each example in the batch.
            """
            return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

        # Select indices for this loop iteration.
        def _Seeded():
            return tf.cond(tf.less(curr_idx, num_seeded_points),
                           _GetSeededPoint, _GetFurthestPoint)

        def _Real():
            return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                           _GetFurthestPoint)

        new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded,
                               _Real)
        sampled_idx = sampled_idx.write(curr_idx, new_selected)

        # Extract the distance to the latest point selected to update
        # distance_to_selected.
        new_selected_gather_idx = tf.stack(
            [tf.range(batch_size), new_selected], axis=1)
        if precomputed_squared_distance is not None:
            new_distance = tf.gather_nd(precomputed_squared_distance,
                                        new_selected_gather_idx)
        else:
            new_points = tf.reshape(
                tf.gather_nd(points, new_selected_gather_idx),
                [batch_size, 1, dims])
            new_distance = tf.reshape(
                SquaredDistanceMatrix(points, new_points),
                [batch_size, num_points])

        is_newly_closest = tf.less(new_distance, distance_to_selected)
        distance_to_selected = tf.minimum(distance_to_selected, new_distance)

        # Track the closest selected point. The loop index (the position in
        # sampled_idx) is stored, not the point index itself.
        new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
        closest_idx = tf.cond(
            tf.equal(curr_idx, 0),
            # At the first loop iteration, the init points are the closest.
            lambda: new_selected_tiled,
            # Otherwise, update with the new points based on the distances.
            lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx)
        )
        return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

    _, _, sampled_idx, closest_idx = tf.while_loop(
        lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
        _BodyFn,
        loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
        back_prop=False,
        maximum_iterations=num_sampled_points)

    sampled_idx = sampled_idx.stack()  # num_sampled_points x n
    sampled_idx = tf.transpose(sampled_idx, [1, 0])

    if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
        sampled_idx.set_shape((batch_size, num_sampled_points))

    return sampled_idx, closest_idx
Exemple #29
0
    def ComputePredictions(self,
                           encoder_outputs,
                           pronunciations,
                           is_inference=False):
        """Decodes step by step from encoder_outputs and computes the loss.

        Despite the name, this function does the bulk of the decoding: it runs
        the decoder for max_pronunciation_len steps in a tf.while_loop, then
        computes the masked loss over the stacked per-step predictions.

        Args:
          encoder_outputs: a NestedMap of outputs of the
            FeatureNeighborhoodEncoder. This method reads `state` (initial
            decoder state) and `batch` (passed through); the whole map is also
            handed to self.Decode (other fields, e.g. the spelling and neighbor
            encodings, are presumably consumed there — TODO confirm).
          pronunciations: NestedMap whose `pronunciations` field is a
            [batch, max_pronunciation_len] tensor of target ids. The time
            dimension must be statically known. Id 0 is treated as padding.
          is_inference: If False then uses teacher forcing else does
            autoregression.

        Returns:
          NestedMap with:
            loss: scalar mean of per_sequence_losses.
            per_sequence_losses: [batch] summed loss divided by
              max_pronunciation_len.
            labels: [batch, max_pronunciation_len] argmax predictions.
            attention: spelling attention, transposed to batch-major
              (presumably [batch, max_pronunciation_len, max_spelling_len]).
            neighbor_attention: neighbor attention, transposed to batch-major
              when p.use_neighbors is set.
            batch: the raw batch passed through from encoder_outputs.
        """
        p = self.params
        targets = pronunciations.pronunciations
        # The decode length must be static: it sizes the TensorArrays below.
        t_len = int(targets.get_shape().as_list()[1])

        attention = tf.TensorArray(dtype=tf.float32, size=t_len)
        neighbor_attention = tf.TensorArray(dtype=tf.float32, size=t_len)
        outputs = tf.TensorArray(dtype=tf.float32, size=t_len)

        # Size the start-token vector from the actual targets rather than
        # p.input.batch_size, so batches smaller than the configured size
        # (e.g. a final partial eval batch) still line up.
        batch_size = py_utils.GetShape(targets)[0]
        dec_input = tf.fill([batch_size], p.start)
        state = encoder_outputs.state

        def loop_cond(t_idx, *unused_loop_vars):
            """Continues until every target position has been decoded."""
            return tf.less(t_idx, t_len)

        def loop_body(t_idx, dec_input, attention, neighbor_attention, state,
                      outputs):
            """Runs one decoder step and records its results at index t_idx."""
            decoder_result = self.Decode(encoder_outputs, dec_input, state)

            outputs = outputs.write(t_idx, decoder_result.predictions)
            attention = attention.write(t_idx,
                                        decoder_result.attention_weights)
            neighbor_attention = neighbor_attention.write(
                t_idx,
                tf.cast(decoder_result.neighbor_attention_weights,
                        dtype=tf.float32))

            if is_inference:
                # Autoregression: feed back the most likely symbol.
                next_input = tf.cast(
                    tf.argmax(decoder_result.predictions, 1), tf.int32)
            else:
                # Teacher forcing: feed the gold symbol of this step.
                next_input = targets[:, t_idx]
            return (t_idx + 1, next_input, attention, neighbor_attention,
                    decoder_result.state, outputs)

        # The final decoder state is not needed, so it is discarded.
        _, _, attention, neighbor_attention, _, outputs = tf.while_loop(
            loop_cond,
            loop_body,
            loop_vars=[
                tf.constant(0), dec_input, attention, neighbor_attention,
                state, outputs
            ])

        # [t_len, batch, X] -> [batch, t_len, X]
        outputs = tf.transpose(outputs.stack(), [1, 0, 2])
        labels = tf.argmax(outputs, axis=-1)
        # Target id 0 is padding and gets zero weight in the loss.
        mask = tf.cast(tf.math.not_equal(targets, 0), dtype=tf.float32)
        # Assumes _loss_object returns per-token losses of shape
        # [batch, t_len] — TODO confirm its reduction setting.
        loss = self._loss_object(targets, outputs, sample_weight=mask)
        loss = tf.reduce_sum(loss, axis=1)
        # NOTE(review): normalizes by the padded length t_len rather than by
        # the number of unmasked tokens; confirm this is intended.
        per_sequence_losses = loss / t_len
        loss = tf.reduce_mean(per_sequence_losses)

        predictions = py_utils.NestedMap()
        predictions.loss = loss
        predictions.per_sequence_losses = per_sequence_losses
        predictions.labels = labels
        # NOTE(review): tf.squeeze without an axis also drops a batch or length
        # dimension of size 1, which would break the 3-D transpose; consider
        # squeezing only the expected singleton axis.
        predictions.attention = tf.transpose(tf.squeeze(attention.stack()),
                                             perm=[1, 0, 2])
        if p.use_neighbors:
            predictions.neighbor_attention = tf.transpose(
                tf.squeeze(neighbor_attention.stack()), perm=[1, 0, 2])
        else:
            predictions.neighbor_attention = tf.squeeze(
                neighbor_attention.stack())
        # Expose this for subsequent data analysis
        predictions.batch = encoder_outputs.batch
        return predictions