Example #1
def IsWithinBBox3D(points_3d, bboxes_3d):
    """Checks if points are within a 3-d bbox.

  Args:
    points_3d: [num_points, 3] float32 Tensor specifying points in 3-d space as
      [x, y, z] coordinates.
    bboxes_3d: [num_bboxes, 7] float32 Tensor specifying a 3-d bboxes specified
      as [x, y, z, dx, dy, dz, phi] where x, y and z is the center of the box.

  Returns:
    boolean Tensor of shape [num_points, num_bboxes] indicating whether the
    points belong within each box.
  """
    points_3d = py_utils.HasRank(points_3d, 2)
    points_3d = py_utils.HasShape(points_3d, [-1, 3])
    num_points, _ = py_utils.GetShape(points_3d, 2)

    bboxes_3d = py_utils.HasRank(bboxes_3d, 2)
    bboxes_3d = py_utils.HasShape(bboxes_3d, [-1, 7])
    num_bboxes, _ = py_utils.GetShape(bboxes_3d, 2)

    # Compute the 3-D corners of the bounding boxes.
    bboxes_3d_b = tf.expand_dims(bboxes_3d, 0)
    bbox_corners = BBoxCorners(bboxes_3d_b)
    bbox_corners = py_utils.HasShape(bbox_corners, [1, -1, 8, 3])
    # First four points are the top of the bounding box.
    # Counter-clockwise arrangement of points specifying 2-d Euclidean box.
    #   (x0, y1) <--- (x1, y1)
    #                    ^
    #                    |
    #                    |
    #   (x0, y0) ---> (x1, y0)
    bboxes_2d_corners = bbox_corners[0, :, 0:4, 0:2]
    bboxes_2d_corners = py_utils.HasShape(bboxes_2d_corners, [-1, 4, 2])
    # Determine if points lie within 2-D (x, y) plane for all bounding boxes.
    points_2d = points_3d[:, :2]
    is_inside_2d = IsWithinBBox(points_2d, bboxes_2d_corners)
    is_inside_2d = py_utils.HasShape(is_inside_2d, [num_points, num_bboxes])

    # Determine if points lie within the z-range for all bounding boxes.
    [_, _, z, _, _, dz, _] = tf.split(bboxes_3d, 7, axis=-1)

    def _ComputeLimits(center, width):
        left = center - width / 2.0
        right = center + width / 2.0
        return left, right

    z0, z1 = _ComputeLimits(z, dz)
    z_points = tf.expand_dims(points_3d[:, 2], -1)

    def _BroadcastAcrossPoints(z):
        return tf.transpose(tf.tile(z, [1, num_points]))

    is_inside_z = tf.logical_and(
        tf.less_equal(z_points, _BroadcastAcrossPoints(z1)),
        tf.greater_equal(z_points, _BroadcastAcrossPoints(z0)))
    is_inside_z = py_utils.HasShape(is_inside_z, [num_points, num_bboxes])

    return tf.logical_and(is_inside_z, is_inside_2d)
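To see the z-containment half of this test in isolation, here is a minimal standalone sketch in plain TensorFlow (TF2 eager mode assumed); the toy point and box values are made up and there is no lingvo dependency:

import tensorflow as tf

# Toy data: three point z-coordinates and one box (center z = 0, depth dz = 2).
points_z = tf.constant([0.5, 2.0, -1.5])                 # [num_points]
box_z, box_dz = tf.constant([0.0]), tf.constant([2.0])   # [num_bboxes]
z0 = box_z - box_dz / 2.0   # lower z limit per box
z1 = box_z + box_dz / 2.0   # upper z limit per box
# Broadcast [num_points, 1] against [num_bboxes] -> [num_points, num_bboxes].
is_inside_z = tf.logical_and(points_z[:, None] >= z0, points_z[:, None] <= z1)
print(is_inside_z.numpy())  # [[ True], [False], [False]]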
Example #2
File: pruning.py Project: snsun/lingvo
 def maybe_update_masks():
     with tf.name_scope(self._spec.name):
         is_step_within_pruning_range = tf.logical_and(
             tf.greater_equal(self._global_step,
                              self._spec.begin_pruning_step),
             # If end_pruning_step is negative, keep pruning forever!
             tf.logical_or(
                 tf.less_equal(self._global_step,
                               self._spec.end_pruning_step),
                 tf.less(self._spec.end_pruning_step, 0)))
         is_pruning_step = tf.less_equal(
             tf.add(self._last_update_step,
                    self._spec.pruning_frequency), self._global_step)
         return tf.logical_and(is_step_within_pruning_range,
                               is_pruning_step)
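The same step-window predicate can be exercised on its own with scalar stand-ins for the self._spec fields (the values below are made up):

import tensorflow as tf

global_step = tf.constant(500, tf.int64)
begin_pruning_step = tf.constant(100, tf.int64)
end_pruning_step = tf.constant(-1, tf.int64)  # negative => keep pruning forever
last_update_step = tf.constant(450, tf.int64)
pruning_frequency = tf.constant(50, tf.int64)

is_step_within_pruning_range = tf.logical_and(
    global_step >= begin_pruning_step,
    tf.logical_or(global_step <= end_pruning_step, end_pruning_step < 0))
is_pruning_step = last_update_step + pruning_frequency <= global_step
print(tf.logical_and(is_step_within_pruning_range, is_pruning_step).numpy())  # True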
Example #3
 def Step(recurrent_theta, state0, inputs):
   """Computes one decoder step."""
   del inputs
   with tf.name_scope('single_sampler_step'):
     # Compute logits and states.
     bs_result, bs_state1 = pre_step_callback(
         recurrent_theta.theta,
         recurrent_theta.encoder_outputs,
         tf.expand_dims(state0.ids, 1),  # [batch, 1].
         state0.bs_state,
         num_hyps_per_beam=1)
     batch = tf.shape(bs_result.log_probs)[0]
     state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
     state1.logits = bs_result.log_probs
     # Sample ids from logits. [batch].
     state1.ids = tf.reshape(
         tf.random.stateless_multinomial(
             state1.logits / p.temperature,
             num_samples=1,
             seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
             output_dtype=state0.ids.dtype,
             name='sample_next_id'), [batch])
     if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
       state1.ids = tf.where(
           tf.logical_and(bs_result.is_last_chunk,
                          tf.equal(state1.ids, p.target_eoc_id)),
           tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
     state1.bs_state = post_step_callback(recurrent_theta.theta,
                                          recurrent_theta.encoder_outputs,
                                          state1.ids, bs_state1)
   return state1, py_utils.NestedMap()
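The temperature-scaled sampling at the heart of this step can be reproduced standalone with tf.random.stateless_categorical (the non-deprecated counterpart of the stateless_multinomial call above); the logits, temperature, and seed below are made up:

import tensorflow as tf

logits = tf.math.log([[0.1, 0.7, 0.2]])  # [batch=1, vocab=3]
temperature = 0.8
sampled = tf.random.stateless_categorical(
    logits / temperature, num_samples=1, seed=[42, 0])  # [batch, 1]
ids = tf.reshape(sampled, [tf.shape(logits)[0]])        # [batch]
print(ids.numpy())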
Example #4
    def _inverse_pth_root_graph(self, epsilon):
        graph = tf.Graph()
        with graph.as_default():
            exponent_t = tf.reshape(
                tf.placeholder(dtype=tf.float32, name="exponent", shape=None),
                [])
            # Apply exponent multiplier.
            exponent_t = exponent_t * self._exponent_multiplier
            input_t = tf.placeholder(dtype=tf.float32,
                                     name="input",
                                     shape=None)
            # For p = 2, 4 or 8, we use the iterative Newton-Schulz method for
            # computing the inverse-pth root.
            either_p_2_4_8 = tf.logical_or(
                tf.logical_or(tf.equal(-1.0 / exponent_t, 2),
                              tf.equal(-1.0 / exponent_t, 4)),
                tf.equal(-1.0 / exponent_t, 8))
            # 4096 is the largest dimension for which SVD is tractable.
            greater_than_4096 = tf.greater(tf.shape(input_t)[0], 4096)
            run_specialized_iterative_method = tf.logical_and(
                greater_than_4096, either_p_2_4_8)
            specialized_fn = functools.partial(
                self._specialized_inverse_pth_root, input_t, exponent_t,
                epsilon)
            generalized_fn = functools.partial(
                self._generalized_inverse_pth_root, input_t, exponent_t,
                epsilon)
            output, diff = tf.cond(run_specialized_iterative_method,
                                   specialized_fn, generalized_fn)

            tf.identity(output, "output")
            tf.identity(tf.cast(diff, tf.float32), "diff")
        return graph.as_graph_def().SerializeToString()
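The tf.cond dispatch with functools.partial branches can be illustrated with a toy predicate and two made-up branch functions:

import functools
import tensorflow as tf

def _double(x):
    return x * 2.0

def _halve(x):
    return x / 2.0

x = tf.constant(8.0)
run_first_branch = tf.logical_and(x > 0.0, x < 100.0)
# functools.partial turns each branch into the zero-argument callable
# that tf.cond expects.
out = tf.cond(run_first_branch,
              functools.partial(_double, x),
              functools.partial(_halve, x))
print(out.numpy())  # 16.0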
Example #5
def IsWithinBBox(points, bbox):
    """Checks if points are within a 2-d bbox.

  The function returns true if points are strictly inside the box. It also
  returns true when the points are exactly on the box edges.

  Args:
    points: a float Tensor of shape [..., 2] of points to be tested. The last
      coordinates are (x, y).
    bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last coordinates
      are the four corners of the bbox and (x, y). The corners are assumed to be
      given in counter-clockwise order.

  Returns:
    Tensor: If ``pshape = tf.shape(points)[:-1]`` and
    ``bshape = tf.shape(bbox)[:-2]``, returns a boolean tensor of shape
    ``tf.concat(pshape, bshape)``, where each element is true if the point is
    inside to the corresponding box.  If a point falls exactly on an edge of the
    bbox, it is also true.
  """
    bshape = py_utils.GetShape(bbox)[:-2]
    pshape = py_utils.GetShape(points)[:-1]
    bbox = py_utils.HasShape(bbox, bshape + [4, 2])
    points = py_utils.HasShape(points, pshape + [2])
    # Enumerate all 4 edges:
    v1, v2, v3, v4 = (bbox[..., 0, :], bbox[..., 1, :],
                      bbox[..., 2, :], bbox[..., 3, :])
    v1v2v3_check = tf.reduce_all(_IsCounterClockwiseDirection(v1, v2, v3))
    v2v3v4_check = tf.reduce_all(_IsCounterClockwiseDirection(v2, v3, v4))
    v4v1v2_check = tf.reduce_all(_IsCounterClockwiseDirection(v4, v1, v2))
    v3v4v1_check = tf.reduce_all(_IsCounterClockwiseDirection(v3, v4, v1))
    with tf.control_dependencies([
            py_utils.Assert(v1v2v3_check, [v1, v2, v3]),
            py_utils.Assert(v2v3v4_check, [v2, v3, v4]),
            py_utils.Assert(v4v1v2_check, [v4, v1, v2]),
            py_utils.Assert(v3v4v1_check, [v3, v4, v1])
    ]):
        is_inside = tf.logical_and(
            tf.logical_and(_IsOnLeftHandSideOrOn(points, v1, v2),
                           _IsOnLeftHandSideOrOn(points, v2, v3)),
            tf.logical_and(_IsOnLeftHandSideOrOn(points, v3, v4),
                           _IsOnLeftHandSideOrOn(points, v4, v1)))
    # Swap the last two dimensions.
    ndims = is_inside.shape.ndims
    return tf.transpose(is_inside,
                        list(range(ndims - 2)) + [ndims - 1, ndims - 2])
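Each _IsOnLeftHandSideOrOn edge test reduces to a 2-d cross product: a point p is left of (or on) the directed edge v1 -> v2 iff cross(v2 - v1, p - v1) >= 0. A minimal standalone version of that idea (lingvo's actual helper may differ in detail):

import tensorflow as tf

def is_on_left_hand_side_or_on(p, v1, v2):
    # 2-d cross product of (v2 - v1) and (p - v1); >= 0 means left of or on the edge.
    d1, d2 = v2 - v1, p - v1
    return d1[..., 0] * d2[..., 1] - d1[..., 1] * d2[..., 0] >= 0.0

p = tf.constant([0.5, 0.5])
v1, v2 = tf.constant([0.0, 0.0]), tf.constant([1.0, 0.0])
print(is_on_left_hand_side_or_on(p, v1, v2).numpy())  # True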
Example #6
    def _GetMask(self,
                 batch_size,
                 max_length,
                 choose_range,
                 mask_size,
                 dtype=tf.float32,
                 max_ratio=1.0):
        """Returns a fixed size mask starting from a random position.

    In this function:
      1) Sample random lengths less than max_length with shape (batch_size,).
      2) Truncate lengths to a max of range * max_ratio, so that each mask is
         fully contained within the corresponding sequence.
      3) Random sample start points with in (choose_range - lengths).
      4) Return a mask where (lengths - start points) * mask_size are all zeros.

    Args:
      batch_size: batch size.
      max_length: Maximum number of allowed consecutive masked entries.
      choose_range: Range within which the masked entries must lie.
      mask_size: Size of the mask.
      dtype: Data type.
      max_ratio: Maximum portion of the entire range allowed to be masked.

    Returns:
      mask: a fixed size mask starting from a random position with shape
      [batch_size, seq_len].
    """
        p = self.params
        # Make sure the sampled length is no larger than max_ratio * length_bound.
        # Note that sampling in this way is biased
        # (shorter sequences may be over-masked).
        length_bound = tf.cast(choose_range, dtype=dtype)
        length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
        length = tf.minimum(max_length, tf.maximum(length_bound, 1))
        # Choose starting point
        random_start = tf.random.uniform((batch_size, ),
                                         maxval=1.0,
                                         seed=p.random_seed)
        start_with_in_valid_range = random_start * tf.cast(
            (choose_range - length + 1), dtype=dtype)
        start = tf.cast(start_with_in_valid_range, tf.int32)
        end = start + length - 1

        # Shift starting and end point by small value
        delta = tf.constant(0.1)
        start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
        end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)

        # Construct mask
        diagonal = tf.tile(
            tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0),
            [batch_size, 1])
        mask = 1.0 - tf.cast(tf.logical_and(diagonal < end, diagonal > start),
                             dtype=dtype)
        if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
            mask = tf.cast(mask, p.fprop_dtype)
        return mask
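The core masking trick, comparing a [0, mask_size) ramp against per-row (start, end) bounds, works standalone with made-up bounds (the +/- 0.1 shift mirrors the delta above):

import tensorflow as tf

mask_size = 8
start = tf.constant([[2.0], [5.0]]) - 0.1  # [batch, 1]
end = tf.constant([[4.0], [6.0]]) + 0.1    # [batch, 1]
ramp = tf.cast(tf.range(mask_size), tf.float32)[None, :]  # [1, mask_size]
mask = 1.0 - tf.cast(tf.logical_and(ramp < end, ramp > start), tf.float32)
print(mask.numpy())  # row 0 zeros positions 2..4, row 1 zeros positions 5..6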
Example #7
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    # [num_hyps_per_beam, src_batch]
                    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                    # num_hyps_per_beam * src_batch
                    tgt_batch = tf.shape(step_ids)[0]
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                # avoid 0 probs which may cause issues with log
                uncertainty = tf.constant(1e-10, p.dtype)
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent
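The final interpolation, blending predicted probs with (nearly) one-hot label probs, can be checked in isolation; the vocab size, label, and weight below are made up:

import tensorflow as tf

vocab_size = 4
pred_probs = tf.constant([[0.1, 0.2, 0.3, 0.4]])  # [tgt_batch=1, vocab_size]
uncertainty = 1e-10  # avoid exactly-zero probs before the log
label_probs = tf.one_hot([2], vocab_size,
                         on_value=1.0 - uncertainty,
                         off_value=uncertainty / (vocab_size - 1))
weight = tf.constant([[0.5]])  # [tgt_batch, 1]
probs = (1.0 - weight) * pred_probs + weight * label_probs
print(tf.math.log(probs).numpy())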
Example #8
  def Filter(self, outputs):
    """Optionally filters the data based on context info."""
    p = self.params
    if p.equality_filters is None:
      return 1

    allowed_example = tf.convert_to_tensor(True)
    for filter_key, filter_values in p.equality_filters:
      if filter_key not in outputs:
        raise ValueError(
            'Filter key `{}` not found in extracted data.'.format(filter_key))
      has_allowed_data = tf.reduce_any(
          tf.equal(outputs[filter_key], filter_values))
      allowed_example = tf.logical_and(allowed_example, has_allowed_data)

    not_allowed_example = 1 - tf.cast(allowed_example, tf.int32)
    return 1 + (not_allowed_example * input_extractor.BUCKET_UPPER_BOUND)
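One equality filter can be exercised standalone: the example passes if any extracted value matches an allowed value (the tensors below are made-up stand-ins for outputs[filter_key] and filter_values):

import tensorflow as tf

extracted = tf.constant([3, 7, 7])    # stand-in for outputs[filter_key]
allowed_values = tf.constant([1, 7])  # stand-in for filter_values
# Broadcast [3, 1] against [2] and ask whether any pair matches.
has_allowed_data = tf.reduce_any(tf.equal(extracted[:, None], allowed_values))
print(has_allowed_data.numpy())  # True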
Example #9
 def _ShouldMerge(unused_tokens, candidates):
     """Merge until not possible, or we abort early according to merge_prob."""
     return tf.logical_and(
         tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
         tf.random.uniform([]) < self._merge_prob)
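A standalone sketch of the same stop-or-continue coin flip, with made-up values for NO_TOKEN and self._merge_prob:

import tensorflow as tf

NO_TOKEN = -1      # made-up sentinel
merge_prob = 0.5   # made-up stand-in for self._merge_prob
candidates = tf.constant([NO_TOKEN, 5, NO_TOKEN])
should_merge = tf.logical_and(
    tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
    tf.random.uniform([]) < merge_prob)
print(should_merge.numpy())  # True about half the time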
Example #10
 def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                  unused_other_states_list):
   return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))
Example #11
 def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                  unused_hyp_lens, done_hyps, unused_other_states_list):
     return tf.logical_and(cur_step < max_steps,
                           tf.logical_not(tf.reduce_all(done_hyps)))
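Conditions like these two plug directly into tf.while_loop. A runnable toy loop with the same shape of predicate (the three-hypothesis body is a made-up stand-in for real decoding work):

import tensorflow as tf

max_steps = tf.constant(10)

def loop_continue(cur_step, done_hyps):
    return tf.logical_and(cur_step < max_steps,
                          tf.logical_not(tf.reduce_all(done_hyps)))

def loop_body(cur_step, done_hyps):
    # Mark hypothesis i as done at step i.
    done_hyps = tf.logical_or(done_hyps, tf.equal(tf.range(3), cur_step))
    return cur_step + 1, done_hyps

step, done = tf.while_loop(loop_continue, loop_body,
                           [tf.constant(0), tf.constant([False] * 3)])
print(step.numpy(), done.numpy())  # 3 [ True  True  True]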
Example #12
    def _GetMask(self,
                 batch_size,
                 choose_range,
                 mask_size,
                 global_seed,
                 max_length=None,
                 masks_per_frame=0.0,
                 multiplicity=1,
                 dtype=tf.float32,
                 max_ratio=1.0):
        """Returns fixed size multi-masks starting from random positions.

    A multi-mask is a mask obtained by applying multiple masks.

    This function when max_length is given:
      1) Sample random mask lengths less than max_length with shape
         (batch_size, multiplicity).
      2) Truncate lengths to a max of (choose_range * max_ratio),
         so that each mask is fully contained within the corresponding sequence.
      3) Random sample start points of shape (batch_size, multiplicity)
         with in (choose_range - lengths).
      4) For each batch, multiple masks (whose number is given by the
         multiplicity) are constructed.
      5) Return a mask of shape (batch_size, mask_size) where masks are
         obtained by composing the masks constructed in step 4).
         If masks_per_frame > 0, the number is given by
         min(masks_per_frame * choose_range, multiplicity).
         If not, all the masks are composed. The masked regions are set to zero.

    This function when max_length is not given:
      1) Sample random mask lengths less than (choose_range * max_ratio)
         with shape (batch_size, multiplicity).
      2) Proceed to steps 3), 4) and 5) of the above.

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the masked entries must lie. Tensor of
        shape (batch_size,).
      mask_size: Size of the mask. Integer number.
      global_seed: an integer seed tensor for stateless random ops.
      max_length: Maximum number of allowed consecutive masked entries. Integer
        number or None.
      masks_per_frame: Number of masks per frame. Float number. If > 0, the
        multiplicity of the mask is set to be masks_per_frame * choose_range.
      multiplicity: Maximum number of total masks. Integer number.
      dtype: Data type.
      max_ratio: Maximum portion of the entire range allowed to be masked. Float
        number.

    Returns:
      mask: a fixed size multi-mask starting from a random position with shape
      (batch_size, mask_size).
    """
        p = self.params
        # Non-empty random seed values are only used for testing or when using
        # stateless random ops. seed_1 and seed_2 are set separately to avoid
        # correlation of mask size and mask position.
        if p.use_input_dependent_random_seed:
            seed_1 = global_seed + 1
            seed_2 = global_seed + 2
        elif p.random_seed:
            seed_1 = p.random_seed + 1
            seed_2 = 2 * p.random_seed
        else:
            seed_1 = p.random_seed
            seed_2 = p.random_seed
        # Sample lengths for multiple masks.
        if max_length and max_length > 0:
            max_length = tf.broadcast_to(tf.cast(max_length, dtype),
                                         (batch_size, ))
        else:
            max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
        random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
        masked_portion = random_uniform(shape=(batch_size, multiplicity),
                                        minval=0.0,
                                        maxval=1.0,
                                        dtype=dtype,
                                        seed=seed_1)
        masked_frame_size = self.EinsumBBmBm(max_length, masked_portion)
        masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)
        # Make sure the sampled length is no larger than max_ratio * length_bound.
        # Note that sampling in this way is biased
        # (shorter sequences may be over-masked).
        choose_range = tf.expand_dims(choose_range, -1)
        choose_range = tf.tile(choose_range, [1, multiplicity])
        length_bound = tf.cast(choose_range, dtype=dtype)
        length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
        length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

        # Choose starting point.
        random_start = random_uniform(shape=(batch_size, multiplicity),
                                      maxval=1.0,
                                      seed=seed_2)
        start_with_in_valid_range = random_start * tf.cast(
            (choose_range - length + 1), dtype=dtype)
        start = tf.cast(start_with_in_valid_range, tf.int32)
        end = start + length - 1

        # Shift starting and end point by small value.
        delta = tf.constant(0.1)
        start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
        start = tf.tile(start, [1, 1, mask_size])
        end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
        end = tf.tile(end, [1, 1, mask_size])

        # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
        diagonal = tf.expand_dims(
            tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
        diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
        pre_mask = tf.cast(tf.logical_and(diagonal < end, diagonal > start),
                           dtype=dtype)

        # Sum masks with appropriate multiplicity.
        if masks_per_frame > 0:
            multiplicity_weights = tf.tile(
                tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
                [batch_size, 1])
            multiplicity_tensor = masks_per_frame * tf.cast(choose_range,
                                                            dtype=dtype)
            multiplicity_weights = tf.cast(
                multiplicity_weights < multiplicity_tensor, dtype=dtype)
            pre_mask = self.EinsumBmtBmBt(pre_mask, multiplicity_weights)
        else:
            pre_mask = tf.reduce_sum(pre_mask, 1)
        mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

        if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
            mask = tf.cast(mask, p.fprop_dtype)

        return mask
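The multi-mask composition in steps 4) and 5) can be isolated: build one span mask per (row, multiplicity) slot, then OR them together by summing over the multiplicity axis (the bounds below are made up):

import tensorflow as tf

mask_size = 8
start = tf.constant([[[1.9], [4.9]]])  # [batch=1, multiplicity=2, 1]
end = tf.constant([[[3.1], [6.1]]])
ramp = tf.cast(tf.range(mask_size), tf.float32)[None, None, :]
pre_mask = tf.cast(tf.logical_and(ramp < end, ramp > start), tf.float32)
composed = tf.reduce_sum(pre_mask, 1)           # [batch, mask_size]
mask = 1.0 - tf.cast(composed > 0, tf.float32)  # zero out masked positions
print(mask.numpy())  # zeros at positions 2..3 and 5..6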
Example #13
 def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z,
                     unused_old_mat_z, err, old_err):
   """This method require that we check for divergence every step."""
   return tf.logical_and(i < iter_count, err < old_err)
Example #14
 def _iter_condition(i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
                     run_step):
   return tf.logical_and(
       tf.logical_and(i < iter_count, error > error_tolerance), run_step)
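Both conditions drive tf.while_loop iterations that stop on convergence, divergence, or a step budget. A toy loop with the same guard (the halving body is a made-up stand-in for one Newton-style update):

import tensorflow as tf

iter_count = tf.constant(100)
error_tolerance = tf.constant(1e-3)

def cond(i, error):
    return tf.logical_and(i < iter_count, error > error_tolerance)

def body(i, error):
    return i + 1, error / 2.0  # error shrinks each step

i, error = tf.while_loop(cond, body, [tf.constant(0), tf.constant(1.0)])
print(i.numpy(), error.numpy())  # 10 0.0009765625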
Example #15
        def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                      num_hyps_per_beam, *args, **kwargs):
            """Wrapper for adding bias to _PreBeamSearchStateCallback.

      Biases results.log_probs towards provided encoder_outputs.targets.

      Args:
        theta: a NestedMap of parameters.
        encoder_outputs: a NestedMap computed by encoder.
        step_ids: A tensor of shape [tgt_batch, 1].
        states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
        num_hyps_per_beam: Beam size.
        *args: additional arguments to _PreBeamSearchStepCallback.
        **kwargs: additional arguments to _PreBeamSearchStepCallback.

      Returns:
        A tuple (results, out_states).
        results: A `.NestedMap` of beam search results.
          atten_probs:
            The updated attention probs, of shape [tgt_batch, src_len].
          log_probs:
            Log prob for each of the tokens in the target vocab. This is of
            shape
            [tgt_batch, vocab_size].
        out_states: a `.NestedMap` The updated states. The states relevant here
          are:
          time_step: A scalar indicating current step of decoder.  Must be
            provided and maintained by subclass.
          consistent: A boolean vector of shape [tgt_batch, ] which tracks
              whether each hypothesis has exactly matched
              encoder_outputs.targets
              so far.
      """
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)

            labels = encoder_outputs.targets.labels
            weights = encoder_outputs.targets.weights

            def TileForBeamAndFlatten(tensor):
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                # [num_hyps_per_beam, src_batch]
                tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                # num_hyps_per_beam * src_batch
                tgt_batch = tf.shape(step_ids)[0]
                return tf.reshape(tensor, [tgt_batch])

            # Consistent if step_ids == labels from previous step
            # TODO(navari): Consider updating consistent only if weights > 0. Then
            # re-evaluate the need for bias_only_if_consistent=True.
            # Note that prev_label is incorrect for step 0 but is overridden later
            prev_label = TileForBeamAndFlatten(
                tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
            is_step0 = tf.equal(time_step, 0)
            local_consistence = tf.logical_or(
                is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
            out_states.consistent = tf.logical_and(states.consistent,
                                                   local_consistence)

            # get label, weight slices corresponding to current time_step
            label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
            weight = TileForBeamAndFlatten(
                tf.gather(weights, time_step, axis=1))
            if p.bias_only_if_consistent:
                weight = weight * tf.cast(out_states.consistent, p.dtype)

            # convert from dense label to sparse label probs
            vocab_size = tf.shape(bs_results.log_probs)[1]
            # avoid 0 probs which may cause issues with log
            uncertainty = tf.constant(1e-10, p.dtype)
            label_probs = tf.one_hot(label,
                                     vocab_size,
                                     on_value=1 - uncertainty,
                                     off_value=uncertainty /
                                     tf.cast(vocab_size - 1, p.dtype),
                                     dtype=p.dtype)  # [tgt_batch, vocab_size]
            pred_probs = tf.exp(bs_results.log_probs)

            # interpolate predicted probs and label probs
            weight = tf.expand_dims(weight, 1)
            probs = py_utils.with_dependencies([
                py_utils.assert_less_equal(weight, 1.),
                py_utils.assert_greater_equal(weight, 0.)
            ], (1.0 - weight) * pred_probs + weight * label_probs)

            bs_results.log_probs = tf.log(probs)

            return bs_results, out_states