Example #1
  def _GetMask(self,
               batch_size,
               choose_range,
               mask_size,
               max_length=None,
               masks_per_frame=0.0,
               multiplicity=1,
               dtype=tf.float32,
               max_ratio=1.0):
    """Returns fixed size multi-masks starting from random positions.

    A multi-mask is a mask obtained by applying multiple masks.

    When max_length is given, this function:
      1) Sample random mask lengths less than max_length with shape
         (batch_size, multiplicity).
      2) Truncate lengths to a max of (choose_range * max_ratio),
         so that each mask is fully contained within the corresponding sequence.
      3) Randomly sample start points of shape (batch_size, multiplicity)
         within (choose_range - lengths).
      4) For each batch, multiple masks (whose number is given by the
         multiplicity) are constructed.
      5) Return a mask of shape (batch_size, mask_size) where masks are
         obtained by composing the masks constructed in step 4).
         If masks_per_frame > 0, the number is given by
         min(masks_per_frame * choose_range, multiplicity).
         If not, all the masks are composed. The masked regions are set to zero.

    When max_length is not given, this function:
      1) Sample random mask lengths less than (choose_range * max_ratio)
         with shape (batch_size, multiplicity).
      2) Proceed to steps 3), 4) and 5) of the above.

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the masked entries must lie. Tensor of
        shape (batch_size,).
      mask_size: Size of the mask. Integer number.
      max_length: Maximum number of allowed consecutive masked entries. Integer
        number or None.
      masks_per_frame: Number of masks per frame. Float number. If > 0, the
        multiplicity of the mask is set to be masks_per_frame * choose_range.
      multiplicity: Maximum number of total masks. Integer number.
      dtype: Data type.
      max_ratio: Maximum portion of the entire range allowed to be masked. Float
        number.

    Returns:
      mask: a fixed size multi-mask starting from a random position with shape
      (batch_size, mask_size).
    """
    p = self.params
    # Non-empty random seed values are only used for testing.
    # seed_1 and seed_2 are set separately to avoid correlation of
    # mask size and mask position.
    if p.random_seed:
      seed_1 = p.random_seed + 1
      seed_2 = 2 * p.random_seed
    else:
      seed_1 = p.random_seed
      seed_2 = p.random_seed
    # Sample lengths for multiple masks.
    if max_length and max_length > 0:
      max_length = tf.broadcast_to(tf.cast(max_length, dtype), (batch_size,))
    else:
      max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
    masked_portion = tf.random.uniform((batch_size, multiplicity),
                                       minval=0.0,
                                       maxval=1.0,
                                       dtype=dtype,
                                       seed=seed_1)
    masked_frame_size = tf.einsum('b,bm->bm', max_length, masked_portion)
    masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)
    # Make sure the sampled length is smaller than max_ratio * length_bound.
    # Note that sampling in this way is biased
    # (shorter sequences may be over-masked).
    choose_range = tf.expand_dims(choose_range, -1)
    choose_range = tf.tile(choose_range, [1, multiplicity])
    length_bound = tf.cast(choose_range, dtype=dtype)
    length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
    length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

    # Choose starting point.
    random_start = tf.random.uniform((batch_size, multiplicity),
                                     maxval=1.0,
                                     seed=seed_2)
    start_with_in_valid_range = random_start * tf.cast(
        (choose_range - length + 1), dtype=dtype)
    start = tf.cast(start_with_in_valid_range, tf.int32)
    end = start + length - 1

    # Shift starting and end point by small value.
    delta = tf.constant(0.1)
    start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
    start = tf.tile(start, [1, 1, mask_size])
    end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
    end = tf.tile(end, [1, 1, mask_size])

    # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
    diagonal = tf.expand_dims(
        tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
    diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
    pre_mask = tf.cast(
        tf.logical_and(diagonal < end, diagonal > start), dtype=dtype)

    # Sum masks with appropriate multiplicity.
    if masks_per_frame > 0:
      multiplicity_weights = tf.tile(
          tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
          [batch_size, 1])
      multiplicity_tensor = masks_per_frame * tf.cast(choose_range, dtype=dtype)
      multiplicity_weights = tf.cast(
          multiplicity_weights < multiplicity_tensor, dtype=dtype)
      pre_mask = tf.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
    else:
      pre_mask = tf.reduce_sum(pre_mask, 1)
    mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

    if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
      mask = tf.cast(mask, p.fprop_dtype)

    return mask
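
A minimal standalone sketch of the same multi-mask construction in plain TensorFlow 2 (no Lingvo params; the seeding, fprop_dtype handling and the masks_per_frame branch are omitted, and all names are illustrative):

import tensorflow as tf

def multi_mask_sketch(batch_size, choose_range, mask_size,
                      max_length, multiplicity, max_ratio=1.0):
  """Illustrative re-implementation of the multi-mask construction above."""
  choose_range = tf.cast(choose_range, tf.float32)               # (batch_size,)
  # 1) Sample mask lengths in [0, max_length).
  portion = tf.random.uniform([batch_size, multiplicity])
  length = tf.cast(tf.cast(max_length, tf.float32) * portion, tf.int32)
  # 2) Truncate lengths to at most choose_range * max_ratio.
  bound = tf.cast(max_ratio * choose_range, tf.int32)[:, None]
  length = tf.minimum(length, tf.maximum(bound, 1))
  # 3) Sample start points within (choose_range - length).
  start = tf.cast(
      tf.random.uniform([batch_size, multiplicity]) *
      tf.cast(tf.cast(choose_range, tf.int32)[:, None] - length + 1, tf.float32),
      tf.int32)
  end = start + length - 1
  # 4)-5) Compose the per-multiplicity masks; masked positions become 0.
  coords = tf.range(mask_size)[None, None, :]                    # (1, 1, mask_size)
  pre_mask = tf.cast(
      (coords >= start[:, :, None]) & (coords <= end[:, :, None]), tf.float32)
  return 1.0 - tf.cast(tf.reduce_sum(pre_mask, axis=1) > 0, tf.float32)

# Example: mask two sequences of lengths 8 and 10 out of 12 frames.
mask = multi_mask_sketch(batch_size=2, choose_range=[8, 10],
                         mask_size=12, max_length=3, multiplicity=2)
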
Example #2
    def __init__(self,
                 learning_rate,
                 momentum=0.0,
                 initial_accumulator_value=0.0,
                 start_preconditioning_steps=1000,
                 statistics_computation_frequency=1,
                 matrix_epsilon=1e-6,
                 synchronous_preconditioning=False,
                 second_moment_averaging=1.0,
                 fallback_to_diagonal_dim=4096,
                 max_any_dim=6656,
                 block_size=4096,
                 block_partition_threshold_size=1000000,
                 global_step=None,
                 exponent_multiplier=1.0,
                 name="DistributedShampoo"):
        """Construct a DistributedShampoo optimizer.

    Args:
      learning_rate: A `Tensor` or a floating point value.  The learning rate.
      momentum: A `Tensor` or a floating point value. Momentum is not applied to
        sparse updates.
      initial_accumulator_value: A floating point value.
      start_preconditioning_steps: An int32 value which indicates when to start
        preconditioning.
      statistics_computation_frequency: An int32 step value which indicates how
        often to compute statistics for preconditioning.
      matrix_epsilon: An epsilon regularizer to make the matrices positive
        definite.
      synchronous_preconditioning: Whether to run preconditioning synchronously.
      second_moment_averaging: 1.0 means sum of gradients squares, while less
        than 1.0 switches to RMSProp style exponential moving averages of the
        second moments.
      fallback_to_diagonal_dim: Fall back to the diagonal version of AFMA if any
        of the dimensions is larger than fallback_to_diagonal_dim.
      max_any_dim: If the maximum size of any dimension is greater than this
        value, we skip preconditioning and fall back to the diagonal.
      block_size: Dimension of the partitioned tensors.
      block_partition_threshold_size: Partitions dimensions beyond this size.
      global_step: Global step for training.
      exponent_multiplier: A multiplier `e` for the exponent used in the inverse
        calculation: e * -1/(2*rank). Only applies when calculating inverses
        through SVD.
      name: Optional name prefix for the operations created when applying
        gradients.
    """
        super(DistributedShampoo, self).__init__(False, name)
        self._learning_rate = learning_rate
        self._momentum = momentum
        self._initial_accumulator_value = initial_accumulator_value
        self._start_preconditioning_steps = start_preconditioning_steps
        self._matrix_epsilon = matrix_epsilon
        self._synchronous_preconditioning = synchronous_preconditioning
        self._second_moment_averaging = second_moment_averaging
        self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
        self._max_any_dim = max_any_dim
        self._block_size = block_size
        # NOTE: On XLA - int64 is not handled properly.
        if global_step is not None:
            self._global_step = tf.cast(tf.identity(global_step), tf.int32)
        else:
            self._global_step = tf.cast(
                tf.identity(tf.train.get_or_create_global_step()), tf.int32)
        self._run_nondiagonal_update = tf.greater_equal(
            self._global_step, self._start_preconditioning_steps)
        start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
        global_step_f = tf.cast(self._global_step, tf.float32)
        self._run_nondiagonal_update_warmup = tf.minimum(
            1.0,
            tf.maximum((global_step_f - start_steps_f) / start_steps_f, 0.0))
        # Computes statistics every K steps.
        self._statistics_computation_frequency = statistics_computation_frequency
        self._run_statistics_computation = tf.equal(
            tf.math.floormod(self._global_step,
                             self._statistics_computation_frequency), 0)
        # All vars that are preconditioned.
        self._all_vars_for_preconditioning = []
        self._exponent_multiplier = exponent_multiplier
        self._partition_info = PartitionConfig(block_partition_threshold_size,
                                               block_size)
        self._partitioner_metadata = {}
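
As a rough, self-contained illustration of the step-dependent gates this constructor builds (plain TensorFlow; the concrete step values below are made up):

import tensorflow as tf

def shampoo_gates_sketch(global_step, start_preconditioning_steps=1000,
                         statistics_computation_frequency=1):
  """Recomputes the three step-dependent quantities from the constructor."""
  step = tf.cast(global_step, tf.int32)
  start_f = tf.cast(start_preconditioning_steps, tf.float32)
  step_f = tf.cast(step, tf.float32)
  # Whether the non-diagonal (preconditioned) update is active at all.
  run_nondiagonal = tf.greater_equal(step, start_preconditioning_steps)
  # Linear warmup of the preconditioned update, clipped to [0, 1].
  warmup = tf.minimum(1.0, tf.maximum((step_f - start_f) / start_f, 0.0))
  # Whether statistics are recomputed at this step (every K steps).
  run_statistics = tf.equal(
      tf.math.floormod(step, statistics_computation_frequency), 0)
  return run_nondiagonal, warmup, run_statistics

# E.g. at step 1500 with start_preconditioning_steps=1000, warmup == 0.5.
print(shampoo_gates_sketch(1500))
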
Example #3
    def _GetWarpMatrix(self,
                       batch_size,
                       choose_range,
                       matrix_size,
                       global_seed,
                       max_warp_frames=None,
                       dtype=tf.float32,
                       max_ratio=1.0):
        """Returns warp matrices starting from random positions.

    When max_warp_frames != None, this function:
      1) Sample random warp displacements from the interval
         [-max_warp_frames, max_warp_frames) to yield shift tensor
         with shape (batch_size,).
      2) Truncate lengths to a maximum magnitude of (choose_range * max_ratio),
         so that each shift is fully contained within the
         corresponding sequence.
      3) Randomly sample origin points of shape (batch_size,)
         within [shift, choose_range - shift).
      4) Return a batch of 1-D linear maps that fix the boundary points and
         shift the origin point by the shift.

    When max_warp_frames == None:
      1) Sample random warp displacements with magnitudes less than
         (choose_range * max_ratio) to yield shift tensor with
         shape (batch_size,).
      2) Proceed through steps 3), 4).

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the warp reference points must lie.
        Tensor of shape (batch_size,).
      matrix_size: Dimension of vector space warp matrix is applied to. Integer
        number.
      global_seed: an integer seed tensor for stateless random ops.
      max_warp_frames: Upper-bound on the warp distance. Integer or None.
      dtype: Data type.
      max_ratio: Maximum ratio between the shift distance and choose_range.
        Float number.

    Returns:
      warp_matrix: An array of fixed size warp matrices with shape
      (batch_size, matrix_size, matrix_size).
    """
        p = self.params
        # Non-empty random seed values are only used for testing or when using
        # stateless random ops. seed_3, seed_4, and seed_5 are set separately to
        # avoid correlation of warp magnitude and origin position.
        if p.use_input_dependent_random_seed:
            seed_3 = global_seed + 3
            seed_4 = global_seed + 4
            seed_5 = global_seed + 5
        elif p.random_seed:
            seed_3 = p.random_seed - 1
            seed_4 = p.random_seed - 1
            seed_5 = 2 * p.random_seed + 1
        else:
            seed_3 = p.random_seed
            seed_4 = p.random_seed
            seed_5 = p.random_seed

        choose_range_dtype = tf.cast(choose_range, dtype=dtype)
        length_upper_bound = tf.cast(max_ratio * choose_range_dtype,
                                     dtype=tf.int32)
        # Set shift length.

        random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)

        if max_warp_frames and max_warp_frames > 0:
            shift = random_uniform(shape=(batch_size, ),
                                   minval=-1 * max_warp_frames,
                                   maxval=max_warp_frames + 1,
                                   dtype=tf.int32,
                                   seed=seed_3)
        else:
            random_ratio = random_uniform(shape=(batch_size, ),
                                          minval=-1.0,
                                          maxval=1.0,
                                          dtype=dtype,
                                          seed=seed_4)
            shift = tf.cast(
                random_ratio * tf.cast(length_upper_bound, dtype=dtype),
                tf.int32)
        # Make sure the sampled shift magnitude is smaller than
        # max_ratio * choose_range. Note that sampling in this way is biased.
        # (Shorter sequences may be over-warped.)
        final_shift = tf.maximum(-length_upper_bound,
                                 tf.minimum(shift, length_upper_bound))
        # Choose origin anchor point.
        mid_range = tf.cast(choose_range, dtype=tf.int32)
        mid_range = tf.maximum(choose_range - 2, 0)
        random_origin = random_uniform(shape=(batch_size, ),
                                       maxval=1.0,
                                       seed=seed_5)
        origin_with_in_valid_range = random_origin * tf.cast(mid_range,
                                                             dtype=dtype)
        origin = tf.cast(origin_with_in_valid_range, tf.int32) + 1
        # Set destination point of the origin anchor point under the warp map.
        destination = origin + final_shift
        # Cast origin and destination.
        origin = tf.cast(origin, dtype=dtype)
        destination = tf.cast(destination, dtype=dtype)

        return self._ConstructWarpMatrix(batch_size=batch_size,
                                         matrix_size=matrix_size,
                                         origin=origin,
                                         destination=destination,
                                         choose_range=choose_range_dtype,
                                         dtype=dtype)
Example #4
    def _ConstructWarpMatrix(self, batch_size, matrix_size, origin,
                             destination, choose_range, dtype):
        """Returns warp matrices according to origin, destination and choose_range.

    This function constructs a batch of warp matrices which maps the batch
    of origin points to the batch of destination points with fixed boundary
    coordinates at 0 and choose_range.

    The warping function, defined by the origin anchor point `origin`,
    the destination of the origin anchor point `destination` and the
    length of the domain in the warping axis `choose_range` is a piecewise
    linear map that fixes the points 0 and `choose_range` and maps
    `origin` to `destination`.

    For the warping matrix to be non-singular, destination must lie in the
    range 1<= destination <= choose_range - 1, so a destination
    out of this range is adjusted to be in this range before the warping
    matrix is constructed.

    The warping map can be explicitly written by first defining the slopes:
      1) slope_0 = origin / destination.
      2) slope_1 = (choose_range - origin) / (choose_range - destination).
      3) slope_2 = 1.0.

    Then the origin point orig_i of the mapped coordinate i is given by:
      1) i < destination: orig_i = slope_0 * i.
      2) destination <= i < choose_range:
         orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
      3) i >= choose_range: orig_i = i.

    Denoting n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
      1) j = n_i: 1 - n_i + orig_i.
      2) j = n_i - 1: n_i - orig_i.
      3) Otherwise: 0.

    Applying the warp matrix to an array of pixels, i.e.,
    warped_pixel[i] = sum_j warp[i][j] * pixel[j], one would get
    warped_pixel[i] = (n_i-orig_i) pixel[n_i-1] + (1-n_i+orig_i) pixel[n_i].

    Args:
      batch_size: Batch size. Integer number.
      matrix_size: Dimension of the vector space the warp matrix is applied to.
        Integer number.
      origin: Origin anchor point for warping. Tensor of shape (batch_size,) and
        data type dtype.
      destination: Destination of the origin anchor point upon warping. Tensor
        of shape (batch_size,) and data type dtype.
      choose_range: Range within which the warp reference points must lie.
        Tensor of shape (batch_size,) data type dtype.
      dtype: Data type of origin, destination, choose_range and the output warp
        matrix.

    Returns:
      warp_matrix: An array of fixed size warp matrices with shape
      (batch_size, matrix_size, matrix_size).
    """
        p = self.params

        # Entries of destination must be in the range
        # 1 <= destination <= choose_range - 1
        # for warp matrix to have non-singular values.
        destination = tf.minimum(tf.maximum(destination, 1.0),
                                 choose_range - 1.0)

        # Construct a piecewise-linear function that fixes the boundary points
        # specified by zero, choose_range and matrix_size, and maps the origin
        # anchor point to the destination.
        destination_bc = tf.broadcast_to(destination,
                                         (matrix_size, batch_size))
        destination_bc = tf.transpose(destination_bc)
        choose_range_bc = tf.broadcast_to(choose_range,
                                          (matrix_size, batch_size))
        choose_range_bc = tf.transpose(choose_range_bc)

        # Slopes of piece-wise linear function.
        slope_0 = origin / destination
        slope_1 = (choose_range - origin) / (choose_range - destination)
        slope_2 = 1.0

        # x is a batch of origin matrices.
        # The origin matrix is the matrix such that
        # origin[i][j] = Origin coordinate of coordinate i for the warp map.
        # Denoting the destination of the origin anchor point in the
        # warp map as "dest," the origin coordinate of point i is given by:
        # 1) i < dest: slope_0 * i.
        # 2) dest <= i < choose_range: slope_1 * i - (slope_1 - slope_0) * dest.
        # 3) i >= choose_range: i.
        x = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size))
        x = (self.EinsumBBmBm(slope_0, x) + self.EinsumBBmBm(
            slope_1 - slope_0, tf.nn.relu(x - destination_bc)) +
             self.EinsumBBmBm(slope_2 - slope_1,
                              tf.nn.relu(x - choose_range_bc)))
        x = tf.broadcast_to(x, (matrix_size, batch_size, matrix_size))
        x = tf.transpose(x, perm=[1, 2, 0])

        # y is a batch of coordinate matrices.
        # A coordinate matrix is a matrix such that
        # coordinate[i][j] = j.
        y = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size, matrix_size))
        # Warp matrix is obtained by applying hat function element-wise to (x-y).
        # Denoting the origin point of i under the warp map as orig_i,
        # and n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
        # 1) j = n_i: 1 - n_i + orig_i.
        # 2) j = n_i - 1: n_i - orig_i.
        # 3) Otherwise: 0.
        # Applying the warp matrix to pixels, i.e.,
        # warped_pixel[i] = sum_j warp[i][j] * original_pixel[j], one would get
        # warped_pixel[i] = (n_i - orig_i) * original_pixel[n_i-1]
        #                   + (1 - n_i + orig_i) * original_pixel[n_i].
        warp_matrix = x - y
        warp_matrix = _hat(warp_matrix)
        if p.fprop_dtype is not None and p.fprop_dtype != dtype:
            warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)

        return warp_matrix
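
A compact standalone sketch of the same piecewise-linear warp in plain TensorFlow, assuming `_hat(x)` is the triangular kernel relu(1 - |x|) and treating the EinsumBBmBm calls as per-batch scalar multiplies (both are assumptions, not taken from this snippet):

import tensorflow as tf

def warp_matrix_sketch(origin, destination, choose_range, matrix_size):
  """Illustrative batch of piecewise-linear warp matrices."""
  origin = tf.cast(origin, tf.float32)              # (batch,)
  destination = tf.cast(destination, tf.float32)    # (batch,)
  choose_range = tf.cast(choose_range, tf.float32)  # (batch,)
  # Keep destination in [1, choose_range - 1] so the map stays non-singular.
  destination = tf.clip_by_value(destination, 1.0, choose_range - 1.0)

  slope_0 = origin / destination
  slope_1 = (choose_range - origin) / (choose_range - destination)

  i = tf.cast(tf.range(matrix_size), tf.float32)[None, :]       # (1, matrix_size)
  # Origin coordinate of each output coordinate i (piecewise linear in i).
  orig_i = (slope_0[:, None] * i
            + (slope_1 - slope_0)[:, None] * tf.nn.relu(i - destination[:, None])
            + (1.0 - slope_1)[:, None] * tf.nn.relu(i - choose_range[:, None]))
  j = tf.cast(tf.range(matrix_size), tf.float32)[None, None, :]
  # Triangular "hat" interpolation: warp[i][j] = max(0, 1 - |orig_i - j|).
  return tf.nn.relu(1.0 - tf.abs(orig_i[:, :, None] - j))

# Warp a length-8 axis: move point 3 to point 5, with boundaries 0 and 6 fixed.
w = warp_matrix_sketch(origin=[3.0], destination=[5.0],
                       choose_range=[6.0], matrix_size=8)
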
Example #5
    def FProp(self, theta, inputs, paddings):
        """Applies causal pooling to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor. It is expected to be of shape [batch,
        time].

    Returns:
      outputs, out_paddings pair.
       - outputs: has the same shape as inputs.
       - out_paddings: has the same shape as paddings.
    """

        p = self.params
        if p.left_context == -1:
            if p.pooling_type == 'AVG':
                cumulative_sum = tf.math.cumsum(inputs, axis=1)
                cumulative_count = 1.0 + tf.range(py_utils.GetShape(inputs)[1],
                                                  dtype=p.dtype)
                cumulative_mean = cumulative_sum / cumulative_count[
                    tf.newaxis, :, tf.newaxis, tf.newaxis]
                cumulative_mean *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
                return cumulative_mean, paddings
            else:
                raise NotImplementedError(
                    'Cumulative max pooling not implemented.')

        window_size = p.left_context
        left_pad_size = window_size - 1
        large_negative = p.dtype.max * tf.constant(-0.7, dtype=p.dtype)
        # For max pooling, use a large negative padding value such that the max
        # element is almost always from a non-padding position.
        pad_value = 0 if p.pooling_type == 'AVG' else large_negative
        inputs = tf.pad(inputs, [[0, 0], [left_pad_size, 0], [0, 0], [0, 0]],
                        constant_values=pad_value)

        out_feature = tf.nn.pool(inputs,
                                 window_shape=(window_size, 1),
                                 pooling_type=p.pooling_type,
                                 padding='VALID')

        if p.pooling_type == 'AVG':
            # Count the fraction of non-padding elements inside each pooling window.
            max_seq_len = py_utils.GetShape(paddings)[1]
            num_non_padded_elements = tf.range(1,
                                               1 + max_seq_len,
                                               dtype=p.dtype)
            num_non_padded_elements = tf.minimum(num_non_padded_elements,
                                                 tf.cast(window_size, p.dtype))
            non_padded_ratio = num_non_padded_elements / tf.cast(
                window_size, p.dtype)
            # Divide by non-padding ratios to eliminate the effect of padded zeros.
            out_feature *= tf.math.reciprocal_no_nan(
                non_padded_ratio[tf.newaxis, :, tf.newaxis, tf.newaxis])
        out_feature *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
        return out_feature, paddings
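
For the cumulative branch (p.left_context == -1 with 'AVG' pooling), the computation reduces to a padded running mean over time; a minimal standalone sketch in plain TensorFlow with illustrative shapes:

import tensorflow as tf

def cumulative_avg_pool_sketch(inputs, paddings):
  """Running mean over time for [batch, time, frequency, channel] inputs."""
  cumulative_sum = tf.math.cumsum(inputs, axis=1)
  # Position t has seen t + 1 frames.
  count = 1.0 + tf.range(tf.shape(inputs)[1], dtype=inputs.dtype)
  cumulative_mean = cumulative_sum / count[tf.newaxis, :, tf.newaxis, tf.newaxis]
  # Zero out padded positions in the output.
  cumulative_mean *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
  return cumulative_mean, paddings

x = tf.random.normal([2, 5, 4, 1])
pad = tf.constant([[0., 0., 0., 1., 1.], [0., 0., 0., 0., 0.]])
out, _ = cumulative_avg_pool_sketch(x, pad)
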
Example #6
  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler."""

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create a random values per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
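
For intuition, here is a tiny standalone farthest point sampler over a single, un-padded point cloud (plain TensorFlow with a Python loop; the padding, seeding and tf.while_loop machinery above are omitted):

import tensorflow as tf

def farthest_point_sample_sketch(points, num_samples):
  """points: [num_points, dims]. Returns indices of num_samples sampled points."""
  num_points = tf.shape(points)[0]
  # Start from a random point, then greedily pick the point farthest from the
  # already-selected set.
  selected = [tf.random.uniform([], maxval=num_points, dtype=tf.int32)]
  distance_to_selected = tf.fill(tf.shape(points)[:1], tf.float32.max)
  for _ in range(num_samples - 1):
    new_point = tf.gather(points, selected[-1])
    new_distance = tf.reduce_sum(tf.square(points - new_point), axis=-1)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)
    selected.append(tf.argmax(distance_to_selected, output_type=tf.int32))
  return tf.stack(selected)

idx = farthest_point_sample_sketch(tf.random.normal([100, 3]), num_samples=8)
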
Example #7
  def Value(self):
    return tf.minimum(self.rampup_schedule.Value(),
                      self.decay_schedule.Value())
Example #8
  def Value(self, step=None):
    return tf.minimum(
        self.rampup_schedule.Value(step), self.decay_schedule.Value(step))
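
Both variants simply take the pointwise minimum of a ramp-up schedule and a decay schedule. A self-contained numeric sketch of that composition, using a linear warmup and an exponential decay as illustrative stand-ins for the real schedule objects:

import tensorflow as tf

def combined_schedule_sketch(step, warmup_steps=1000, decay_rate=0.999):
  step = tf.cast(step, tf.float32)
  rampup_value = tf.minimum(step / warmup_steps, 1.0)  # linear warmup to 1.0
  decay_value = tf.pow(decay_rate, step)               # exponential decay
  return tf.minimum(rampup_value, decay_value)

# Early steps follow the ramp, late steps follow the decay.
print(combined_schedule_sketch(100), combined_schedule_sketch(5000))
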
Example #9
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in `x`, which represents the groundtruth sequence (i.e., English
    sequence). We return a sampled rollin (observed) canvas (i.e., random subset
    of the English sequence), as well as the target (indices) for an
    insertion-based model (i.e., the targets given the random observed subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
          equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If invalid params.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        x_len = tf.to_int32(tf.round(tf.reduce_sum(1 - x_paddings, 1)))

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
            # accomplish this by adding +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        c *= tf.cast(1 - c_paddings, tf.int32)

        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed to <eos>, note some of these need a zero weight
        # (or else there would be <eos> and valid token in the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
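
The canvas positions above are drawn with the Gumbel-max trick: mask out invalid positions with a large negative logit, add Gumbel noise, and take top-k. A minimal standalone sketch of that sampling step (plain TensorFlow; the shapes and the minval guard against log(0) are illustrative):

import tensorflow as tf

def sample_valid_positions_sketch(x_len, time_dim, k):
  """Samples k distinct positions < x_len per row (assumes k <= min(x_len))."""
  batch_size = tf.shape(x_len)[0]
  # -1e9 on invalid (>= x_len) positions so they are never picked.
  logits = tf.cast(
      tf.range(time_dim)[None, :] >= x_len[:, None], tf.float32) * -1e9
  # Gumbel noise; minval keeps log() away from zero.
  gumbel = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], minval=1e-9, maxval=1.0)))
  _, indices = tf.nn.top_k(logits + gumbel, k)
  return indices  # [batch_size, k], an unordered sample of valid positions.

idx = sample_valid_positions_sketch(tf.constant([5, 8]), time_dim=10, k=3)
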
Example #10
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        # Create prefix of start tokens.
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(beam_size - 1, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ - - - - - - P P P P P P P* - - - ]   ^
        # [ - - P P P P P P P P P P P* - - - ]   | batch
        # [ - - - - - - - - - - - P P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_indexes = pfx_time - pfx_max + tf.expand_dims(pfx_len - 1, 1)
        pfx_pad = tf.cast(tf.greater_equal(pfx_indexes, 0),
                          tf.int32)  # Exclude final pfx token.
        pfx_id = tf.roll(pfx, shift=1, axis=-1) * pfx_pad
        pfx_last = pfx[:, -1]

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        # Remove padding.
        assert buf_size > pfx_max
        pfx_pad_long = tf.pad(pfx_pad, [(0, 0), (0, buf_size - pfx_max)],
                              constant_values=1)
        pfx_mask *= tf.cast(tf.expand_dims(pfx_pad_long, axis=1), tf.float32)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_indexes * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun): Here we make the initial 't' constant, determined by the
        # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic
        # as t ~ max(pfx_len) / beam_size, which would allow more beam search
        # steps; however 'max' results in a very slow all-to-all on 16x16, and a
        # variable number of decoder steps may result in bad latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, _, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
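
Hypothesis scores in the loop are normalized with the length penalty ((len + 5) / 5) ** alpha before updating the n-best list; a small standalone sketch of that scoring step (the numeric values are made up):

import tensorflow as tf

def length_norm(hyp_len, alpha=0.8):
  hyp_len = tf.cast(hyp_len, tf.float32)
  return tf.math.pow((hyp_len + 5.0) / 5.0, alpha)

def normalized_eos_score(hyp_score, eos_logit, hyp_len, alpha=0.8):
  """Score of a hypothesis if it emitted </s> now, normalized by its length."""
  eos_score = hyp_score + eos_logit
  return eos_score / length_norm(hyp_len, alpha)

# A longer hypothesis needs a proportionally better raw score to compete.
print(normalized_eos_score(hyp_score=-4.0, eos_logit=-0.5, hyp_len=6))
print(normalized_eos_score(hyp_score=-7.0, eos_logit=-0.5, hyp_len=12))
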
Example #11
  def CornerLoss(self, gt_bboxes, predicted_bboxes):
    """Corner regularization loss.

    This function computes the corner loss, an alternative regression loss
    for box residuals. This was used in the Frustum-PointNets paper [1].

    We compute the predicted bboxes (all 8 corners) and compute a SmoothedL1
    loss between the corners of the predicted boxes and ground truth. Hence,
    this loss can help encourage the model to maximize the IoU of the
    predictions.

    [1] Frustum PointNets for 3D Object Detection from RGB-D Data
        https://arxiv.org/pdf/1711.08488.pdf

    TODO(bcyang): support arbitrary input shapes [..., 7].

    Args:
      gt_bboxes: tf.float32 of shape [batch_size, num_centers,
        num_anchor_bboxes_per_center, 7] which contains (x, y, z, dx, dy, dz,
        phi), corresponding to ground truth bbox parameters.
      predicted_bboxes: tf.float32 of same shape as gt_bboxes containing
        predicted bbox parameters.

    Returns:
      tf.float32 Tensor of shape [batch_size, num_centers,
      num_anchor_bboxes_per_center] where each entry contains the corner loss
      for the corresponding bbox.
    """
    batch_size, num_centers, num_anchor_bboxes_per_center = py_utils.GetShape(
        gt_bboxes, 3)
    gt_bboxes = py_utils.HasShape(
        gt_bboxes, [batch_size, num_centers, num_anchor_bboxes_per_center, 7])
    predicted_bboxes = py_utils.HasShape(
        predicted_bboxes,
        [batch_size, num_centers, num_anchor_bboxes_per_center, 7])

    gt_bboxes = tf.reshape(
        gt_bboxes, [batch_size, num_centers * num_anchor_bboxes_per_center, 7])
    predicted_bboxes = tf.reshape(
        predicted_bboxes,
        [batch_size, num_centers * num_anchor_bboxes_per_center, 7])
    rot = tf.constant([[[0., 0., 0., 0., 0., 0., np.pi]]], dtype=tf.float32)
    rotated_gt_bboxes = gt_bboxes + rot

    gt_corners = geometry.BBoxCorners(gt_bboxes)
    rotated_gt_corners = geometry.BBoxCorners(rotated_gt_bboxes)
    predicted_corners = geometry.BBoxCorners(predicted_bboxes)

    corner_dist = tf.norm(predicted_corners - gt_corners, axis=-1)
    rotated_corner_dist = tf.norm(
        predicted_corners - rotated_gt_corners, axis=-1)
    total_dist = tf.reduce_sum(corner_dist, axis=-1)
    rotated_total_dist = tf.reduce_sum(rotated_corner_dist, axis=-1)
    min_dist = tf.minimum(total_dist, rotated_total_dist)

    huber_loss = self.ScaledHuberLoss(
        labels=tf.zeros_like(total_dist), predictions=min_dist)
    huber_loss = tf.reshape(
        huber_loss, [batch_size, num_centers, num_anchor_bboxes_per_center])

    return huber_loss
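
A rough standalone sketch of the corner loss idea in plain TensorFlow: compute the eight corners of each box, then take the minimum of the summed corner distances over the ground truth box and its pi-rotated copy. The corner extraction below is an illustrative stand-in for geometry.BBoxCorners, and the final ScaledHuberLoss wrapper is omitted:

import numpy as np
import tensorflow as tf

def bbox_corners_sketch(bboxes):
  """Illustrative corners for [batch, num_boxes, 7] boxes (x,y,z,dx,dy,dz,phi)."""
  x, y, z, dx, dy, dz, phi = tf.unstack(bboxes, axis=-1)
  # Unit-cube corner offsets, scaled by the half-dimensions of each box.
  signs = tf.constant(
      [[sx, sy, sz] for sx in (-.5, .5) for sy in (-.5, .5) for sz in (-.5, .5)],
      tf.float32)                                            # [8, 3]
  dims = tf.stack([dx, dy, dz], axis=-1)[..., None, :]       # [b, n, 1, 3]
  corners = signs * dims                                     # [b, n, 8, 3]
  # Rotate around the z axis by phi, then translate to the box center.
  cos, sin = tf.cos(phi)[..., None], tf.sin(phi)[..., None]
  cx = corners[..., 0] * cos - corners[..., 1] * sin
  cy = corners[..., 0] * sin + corners[..., 1] * cos
  rotated = tf.stack([cx, cy, corners[..., 2]], axis=-1)
  center = tf.stack([x, y, z], axis=-1)[..., None, :]
  return rotated + center                                    # [b, n, 8, 3]

def corner_loss_sketch(gt_bboxes, predicted_bboxes):
  """Min over the pi-rotated ground truth of the summed corner distances."""
  rot = tf.constant([[[0., 0., 0., 0., 0., 0., np.pi]]])
  gt_corners = bbox_corners_sketch(gt_bboxes)
  rotated_gt_corners = bbox_corners_sketch(gt_bboxes + rot)
  predicted_corners = bbox_corners_sketch(predicted_bboxes)
  dist = tf.reduce_sum(
      tf.norm(predicted_corners - gt_corners, axis=-1), axis=-1)
  rotated_dist = tf.reduce_sum(
      tf.norm(predicted_corners - rotated_gt_corners, axis=-1), axis=-1)
  return tf.minimum(dist, rotated_dist)                      # [b, n]
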