Example #1
    def _BuildCrossBatchMixingDataSource(self):
        """Read and return input batch from a p.file_pattern list.

    `p.file_pattern` should be a list of (file_pattern, weight,
    optional_bprop_filter) tuples. Every batch returned will be filled from one
    source only and batches will be mixed proportionally to the weights.
    Additionally some backprop filters may be applied for different input
    sources.

    Returns:
      A tuple which contains the output of `self._DataSourceFromFilePattern()`
      and a tensor of size [batch_size, number of data sources] which contains
      the source selected for each element in the input batch. With cross batch
      mixing the complete input batch comes from the same source.

    Raises:
      ValueError: If an entry in `p.file_pattern` does not explicitly specify
        its weight.
    """
        p = self.params
        input_file_pattern = p.file_pattern

        def _MakeDataSourceFromFilePatternFunc(file_pattern):
            # It's important to invoke self._DataSourceFromFilePattern() inside
            # the lambda to make sure that the record is drawn from the data
            # source only if it will be used.
            return lambda: self._DataSourceFromFilePattern(file_pattern)

        inputs = []
        weights = []
        self._bprop_variable_filters = []
        for input_entry in input_file_pattern:
            if isinstance(input_entry, six.string_types):
                raise ValueError(
                    'Should explicitly specify weights, got string: %s' %
                    (input_entry, ))
            file_pattern, weight = input_entry[:2]
            inputs.append(_MakeDataSourceFromFilePatternFunc(file_pattern))
            weights.append(weight)
            bprop_variable_filter = input_entry[2] if len(
                input_entry) > 2 else ''
            self._bprop_variable_filters.append(bprop_variable_filter)
        data_source, selected_bprop = py_utils.MixByWeight(inputs, weights)
        # TODO(neerajgaur): Remove _bprop_onehot and change code that uses it to
        # use source_selected from input_batch.
        self._bprop_onehot = selected_bprop
        batch_size = py_utils.GetShape(tf.nest.flatten(data_source)[0])[0]
        return data_source, tf.tile(tf.expand_dims(selected_bprop, 0),
                                    [batch_size, 1])
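For reference, a minimal sketch of the weight-proportional source selection that `py_utils.MixByWeight` provides, written in plain TF2; `mix_by_weight` and its signature are illustrative, not the library API:

import tensorflow as tf

def mix_by_weight(sources, weights):
  # Pick one source per call with probability proportional to `weights`.
  logits = tf.math.log([weights])                       # [1, num_sources]
  idx = tf.cast(tf.random.categorical(logits, 1)[0, 0], tf.int32)
  batch = tf.switch_case(idx, sources)                  # lazy: only one source runs
  return batch, tf.one_hot(idx, len(sources))

batch, selected = mix_by_weight(
    [lambda: tf.zeros([8, 4]), lambda: tf.ones([8, 4])], [0.2, 0.8])
# As in the snippet above, the selection can then be tiled per example:
per_example = tf.tile(selected[tf.newaxis, :], [tf.shape(batch)[0], 1])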
Example #2
  def FProp(self, theta, inputs, paddings, domain_ids=None):
    """Applies data augmentation by randomly mask spectrum in inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: A tensor of shape [batch, time, freq, num_channels].
      paddings: A 0/1 tensor of shape [batch, time].
      domain_ids: input domain_ids of shape [batch, time].

    Returns:
      A pair of tensors:

      - augmented_inputs: A tensor of shape [batch, time, freq, num_channels].
      - paddings: A 0/1 tensor of shape [batch, time].
    """
    p = self.params

    batch_size, series_length, num_freq, _ = py_utils.GetShape(inputs)
    if len(p.domain_ids) > 1:
      augmented_inputs = tf.zeros_like(inputs)
      original_inputs = inputs
      for i, domain_id in enumerate(p.domain_ids):
        augmented_domain = self._AugmentationNetwork(
            series_length, num_freq, inputs, paddings, domain_id_index=i)
        target_domain = tf.cast(
            tf.expand_dims(tf.tile([domain_id], [batch_size]), -1),
            dtype=p.dtype)
        # [batch, time]
        domain_mask = tf.cast(
            tf.equal(domain_ids, target_domain), dtype=p.dtype)
        augmented_domain = tf.einsum(
            'bxyc,bx->bxyc',
            augmented_domain,
            domain_mask,
            name='einsum_domainmasking')
        original_inputs = tf.einsum(
            'bxyc,bx->bxyc',
            original_inputs,
            1.0 - domain_mask,
            name='einsum_domainmasking2')
        augmented_inputs = augmented_domain + augmented_inputs
      augmented_inputs = original_inputs + augmented_inputs
    else:
      augmented_inputs = self._AugmentationNetwork(
          series_length, num_freq, inputs, paddings, domain_id_index=0)
    return augmented_inputs, paddings
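A toy illustration of the `bxyc,bx->bxyc` einsum used above: time steps whose domain id does not match are zeroed out, so the per-domain augmented outputs can simply be summed. Values here are hypothetical:

import tensorflow as tf

inputs = tf.ones([2, 3, 4, 1])                    # [batch, time, freq, channels]
domain_ids = tf.constant([[0, 0, 1], [1, 1, 1]])  # [batch, time]
domain_mask = tf.cast(tf.equal(domain_ids, 0), tf.float32)
masked = tf.einsum('bxyc,bx->bxyc', inputs, domain_mask)
# masked[0, 2] and all of masked[1] are zero; other positions are unchanged.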
Example #3
def favor_attention(query,
                    key,
                    value,
                    paddings,
                    kernel_transformation,
                    causal,
                    projection_matrix=None):
    """Computes FAVOR normalized attention.

  Args:
    query: query tensor.
    key: key tensor.
    value: value tensor.
    paddings: paddings tensor.
    kernel_transformation: transformation used to get finite kernel features.
    causal: whether attention is causal or not.
    projection_matrix: projection matrix to be used.

  Returns:
    FAVOR normalized attention.
  """
    query_prime = kernel_transformation(query, True,
                                        projection_matrix)  # [B,L,H,M]
    key_prime = kernel_transformation(key, False,
                                      projection_matrix)  # [B,L,H,M]
    if paddings is not None:
        b, l, h, m = py_utils.GetShape(key_prime, 4)
        paddings = tf.tile(tf.reshape(paddings, [b, l, 1, 1]), [1, 1, h, m])
        key_prime *= tf.cast(1.0 - paddings, key_prime.dtype)
    query_prime = tf.transpose(query_prime, [1, 0, 2, 3])  # [L,B,H,M]
    key_prime = tf.transpose(key_prime, [1, 0, 2, 3])  # [L,B,H,M]
    value = tf.transpose(value, [1, 0, 2, 3])  # [L,B,H,D]
    # TODO(kchoro): Get rid of the transpose operations, at least in the
    # bidirectional variant.

    if causal:
        av_attention = causal_numerator(query_prime, key_prime, value)
        attention_normalizer = causal_denominator(query_prime, key_prime)
    else:
        av_attention = noncausal_numerator(query_prime, key_prime, value)
        attention_normalizer = noncausal_denominator(query_prime, key_prime)
    # TODO(kchoro): Add more comments.
    av_attention = tf.transpose(av_attention, [1, 0, 2, 3])
    attention_normalizer = tf.transpose(attention_normalizer, [1, 0, 2])
    attention_normalizer = tf.expand_dims(attention_normalizer,
                                          len(attention_normalizer.shape))
    return av_attention / attention_normalizer
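For reference, a minimal sketch of what the bidirectional numerator and denominator compute, assuming the [L,B,H,M] / [L,B,H,D] layout used above; the real noncausal_numerator/noncausal_denominator may differ in details such as numerical stabilization:

import tensorflow as tf

def noncausal_numerator_sketch(qp, kp, v):
  kv = tf.einsum('lbhm,lbhd->bhmd', kp, v)     # (K')^T V, summed over L
  return tf.einsum('lbhm,bhmd->lbhd', qp, kv)  # Q' ((K')^T V)

def noncausal_denominator_sketch(qp, kp):
  ksum = tf.reduce_sum(kp, axis=0)             # (K')^T 1, shape [B,H,M]
  return tf.einsum('lbhm,bhm->lbh', qp, ksum)  # Q' ((K')^T 1)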
Example #4
 def testFarthestPointSamplerGatherPoints(self):
   points = tf.constant([
       [[0, 1, 1], [1, 1, 1], [2, 1, 1], [3, 1, 1], [4, 1, 1], [5, 1, 1]],
       [[0, 2, 1], [1, 2, 1], [2, 2, 1], [3, 2, 1], [4, 2, 1], [5, 2, 1]],
       [[0, 2, 3], [1, 2, 3], [2, 2, 3], [3, 2, 3], [4, 2, 3], [5, 2, 3]],
       [[0, 2, 1], [1, 2, 1], [2, 2, 1], [3, 2, 1], [4, 2, 1], [5, 2, 1]],
   ], dtype=tf.float32)  # pyformat: disable
   padding = tf.zeros((4, 6), dtype=tf.float32)
   n = 4
   num_points = 3
   selected_idx, _ = car_lib.FarthestPointSampler(points, padding, num_points)
   gather_indices = tf.stack([
       tf.tile(tf.expand_dims(tf.range(n), 1), [1, num_points]), selected_idx
   ],
                             axis=2)
   sampled_points = tf.gather_nd(points, gather_indices)
   with self.session() as sess:
     sampled_points = sess.run(sampled_points)
     self.assertEqual(sampled_points.shape, (n, num_points, 3))
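The stack/tile/gather_nd pattern in this test builds explicit (batch, point) index pairs; with TF >= 1.14 the same per-example selection can be written more compactly with batch_dims (indices below are hypothetical):

import tensorflow as tf

points = tf.random.uniform([4, 6, 3])
selected_idx = tf.constant([[0, 2, 5], [1, 3, 4], [0, 1, 2], [5, 4, 3]])
sampled = tf.gather(points, selected_idx, batch_dims=1)  # [4, 3, 3]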
Example #5
def SparseToDense(grid_shape, locations, feats):
  """Converts a sparse representation back to the dense grid.

  Args:
    grid_shape: (nx, ny, nz). The shape of the grid.
    locations: [b, p, 3]. Locations of the pillars.
    feats: [b, p, fdims]. Extracted features for pillars.

  Returns:
    grid_feats: [b, nx, ny, nz * fdims].
  """
  nx, ny, nz = grid_shape
  b, p, _ = py_utils.GetShape(locations, 3)
  feats = py_utils.HasShape(feats, [b, p, -1])
  _, _, fdims = py_utils.GetShape(feats, 3)
  indices = tf.concat(
      [tf.tile(tf.range(b)[:, tf.newaxis, tf.newaxis], [1, p, 1]), locations],
      axis=2)
  grid = tf.scatter_nd(indices, feats, [b, nx, ny, nz, fdims])
  return tf.reshape(grid, [b, nx, ny, nz * fdims])
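A tiny worked example of the scatter above, with hypothetical sizes (two pillars per example written into a 2x2x1 grid): batch indices are prepended to the locations so each pillar lands in its own example's grid.

import tensorflow as tf

b, p, fdims = 2, 2, 3
locations = tf.constant([[[0, 0, 0], [1, 1, 0]],
                         [[0, 1, 0], [1, 0, 0]]])    # [b, p, 3]
feats = tf.ones([b, p, fdims])
batch_idx = tf.tile(tf.range(b)[:, tf.newaxis, tf.newaxis], [1, p, 1])
indices = tf.concat([batch_idx, locations], axis=2)  # [b, p, 4]
grid = tf.scatter_nd(indices, feats, [b, 2, 2, 1, fdims])
dense = tf.reshape(grid, [b, 2, 2, 1 * fdims])       # [b, nx, ny, nz*fdims]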
Example #6
  def _GetWeight(self, theta):
    p = self.params
    filter_w = theta.w

    # First normalize filter_w over the temporal dimension here.
    filter_w = tf.nn.softmax(filter_w / p.temperature, axis=0)

    # Add dropconnect on the weights for regularization.
    if p.dropconnect_prob > 0.0 and not self.do_eval:
      if p.deterministic_dropout:
        filter_w = py_utils.DeterministicDropout(
            filter_w, 1.0 - p.dropconnect_prob,
            py_utils.GenerateStepSeedPair(p))
      else:
        filter_w = tf.nn.dropout(
            filter_w, rate=p.dropconnect_prob, seed=p.random_seed)

    # Tie the parameters of every subsequent number of weight_tiling_factor
    # channels.
    filter_w = tf.tile(filter_w, [1, 1, p.weight_tiling_factor, 1])
    return filter_w
Example #7
    def testMaxCanvasSizeUnderUniformRollinPolicy(self):
        """Tests for valid canvas size."""
        with self.session(use_gpu=True) as sess:
            params = insertion.SymbolInsertionLayer.Params()
            params.name = 'insertion'
            params.rollin_policy = 'oracle'
            params.oracle_policy = 'uniform'

            insertion_layer = insertion.SymbolInsertionLayer(params)

            batch_size = 4
            time_dim = 10

            inputs = tf.tile(tf.expand_dims(tf.range(time_dim), 0),
                             [batch_size, 1])
            inputs_len = tf.random.uniform([batch_size], 0, time_dim, tf.int32)
            paddings = 1 - tf.sequence_mask(inputs_len, time_dim, tf.int32)
            spec = insertion_layer.FProp(None,
                                         inputs,
                                         paddings,
                                         force_sample_last_token=False)

            canvas_with_max_length = False
            for _ in range(1000):
                canvas_max_len, canvas, canvas_paddings = sess.run(
                    [inputs_len, spec.canvas, spec.canvas_paddings])

                for b in range(batch_size):
                    max_len = canvas_max_len[b]
                    length = np.sum(1 - canvas_paddings[b, :]).astype(np.int32)
                    canvas_with_max_length |= length == max_len
                    self.assertLessEqual(length, max_len)
                    # Invalid entries of canvas should be 0.
                    self.assertAllEqual(canvas[b, length:],
                                        [0] * (canvas.shape[1] - length))

            # With high probability, there should be at least one canvas that is
            # of the same size as the maximum canvas size.
            self.assertTrue(canvas_with_max_length)
Example #8
    def _ReshapeToMono2D(self, pcm_audio_data, paddings):
        """Reshapes a 3D or 4D input to 2D.

    Since the input to FProp can be 3D or 4D (see class comments), this will
    collapse it back to a 2D, mono shape for internal processing.

    Args:
      pcm_audio_data: 2D, 3D or 4D audio input. See class comments. Must have a
        rank.
      paddings: Original paddings shaped to the first two dims of
        pcm_audio_data.

    Returns:
      Tuple of 2D [batch_size, timestep] mono audio data, new paddings.
    """
        shape = py_utils.GetShape(pcm_audio_data)
        rank = len(shape)
        if rank == 2:
            return pcm_audio_data, paddings
        elif rank == 3:
            # [batch, time, channel]
            with tf.control_dependencies([tf.assert_equal(shape[2], 1)]):
                return tf.squeeze(pcm_audio_data, axis=2), paddings
        elif rank == 4:
            # [batch, time, packet, channel]
            batch_size, orig_time, orig_packet_size, channel = shape
            time = orig_time * orig_packet_size
            with tf.control_dependencies([tf.assert_equal(channel, 1)]):
                pcm_audio_data = tf.reshape(pcm_audio_data, (batch_size, time))
                # Transform paddings into the new time base (one padding per
                # time step instead of per packet) by duplicating each packet.
                paddings = tf.reshape(
                    tf.tile(tf.expand_dims(paddings, axis=2),
                            [1, 1, orig_packet_size]), (batch_size, time))
                return pcm_audio_data, paddings
        else:
            raise ValueError('Illegal pcm_audio_data shape')
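A small sketch of the rank-4 padding expansion above: each per-packet padding value is repeated packet_size times, so there is one padding value per output time step (numbers are hypothetical):

import tensorflow as tf

paddings = tf.constant([[0., 0., 1.]])  # [batch=1, orig_time=3], one per packet
packet_size = 2
expanded = tf.reshape(
    tf.tile(paddings[:, :, tf.newaxis], [1, 1, packet_size]),
    [1, 3 * packet_size])
# -> [[0., 0., 0., 0., 1., 1.]]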
Example #9
    def testContiguousCanvasUnderUniformRollinPolicy(self):
        """Tests for valid canvas size."""
        with self.session(use_gpu=True) as sess:
            params = insertion.SymbolInsertionLayer.Params()
            params.name = 'insertion'
            params.rollin_policy = 'oracle'
            params.oracle_policy = 'uniform'

            insertion_layer = insertion.SymbolInsertionLayer(params)

            batch_size = 4
            time_dim = 10

            inputs = tf.tile(
                tf.expand_dims(tf.range(time_dim), 0) + 100, [batch_size, 1])
            inputs_len = tf.random.uniform([batch_size], 0, time_dim, tf.int32)
            paddings = 1 - tf.sequence_mask(inputs_len, time_dim, tf.int32)
            spec = insertion_layer.FProp(None,
                                         inputs,
                                         paddings,
                                         force_sample_last_token=False)

            for _ in range(1000):
                canvas, canvas_paddings = sess.run(
                    [spec.canvas, spec.canvas_paddings])

                for b in range(batch_size):
                    length = np.sum(1 - canvas_paddings[b, :]).astype(np.int32)
                    # Check for valid part of the canvas and padding.
                    for l in range(length):
                        self.assertEqual(canvas_paddings[b, l], 0)
                        self.assertNotEqual(canvas[b, l], 0)
                    # Check for invalid part of the canvas and padding.
                    for l in range(length, canvas.shape[1]):
                        self.assertEqual(canvas_paddings[b, l], 1)
                        self.assertEqual(canvas[b, l], 0)
Example #10
    def FProp(self, theta, input_batch):
        # pyformat: disable
        """Compute features for the pillars and convert them back to a dense grid.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      input_batch: A `.NestedMap` object containing input tensors. Following
        keys are required:

        - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
          where nx, ny, nz corresponds to the grid sizes (i.e., number of voxels
          in each axis dimension).
        - pillar_points: Float tensor with shape [batch size, num_pillars,
          num_points_per_pillar, 3 + num_laser_features]
        - pillar_centers: Float tensor with shape [batch size, num_pillars,
          num_points_per_pillar, 3]
        - pillar_locations: Float tensor with shape [batch size, num_pillars, 3]

    Returns:
      The dense features with shape [b, nx, ny, nz * fdims].
    """
        # pyformat: enable
        p = self.params
        bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)
        # Process points to concatenate a set of fixed features (e.g.,
        # add means, centers, normalize points to means).
        num_features = 3 + p.num_laser_features
        pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                          [bs, -1, -1, num_features])
        _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
        pillar_xyz = pillar_points[..., :3]

        # Compute number of points per pillar and prepare for broadcasting.
        pillar_num_points = tf.gather_nd(input_batch.grid_num_points,
                                         input_batch.pillar_locations,
                                         batch_dims=1)
        pillar_num_points = pillar_num_points[..., tf.newaxis, tf.newaxis]

        # Compute mean by computing sum and dividing by number of points. Clip the
        # denominator by 1.0 to gracefully handle empty pillars.
        pillar_sum = tf.reduce_sum(pillar_xyz, axis=2, keepdims=True)
        pillar_means = pillar_sum / tf.maximum(
            tf.cast(pillar_num_points, tf.float32), 1.0)

        pillar_feats = pillar_points[..., 3:]
        pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                           [bs, -1, 1, 3])
        pillar_concat = tf.concat(axis=3,
                                  values=[
                                      pillar_xyz - pillar_means, pillar_feats,
                                      tf.tile(pillar_means,
                                              [1, 1, npoints, 1]),
                                      tf.tile(pillar_centers,
                                              [1, 1, npoints, 1])
                                  ])
        # Featurize pillars.
        pillar_features = self.featurizer.FProp(theta.featurizer,
                                                pillar_concat)

        # Convert back to the dense grid.
        pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                             [bs, npillars, 3])
        dense_features = SparseToDense(grid_shape=(nx, ny, nz),
                                       locations=pillar_locations,
                                       feats=pillar_features)
        return dense_features
Example #11
  def _XYZFromRangeImage(self,
                         lidar_image,
                         lidar_image_mask,
                         extrinsics,
                         inclinations,
                         pixel_pose=None,
                         frame_pose=None):
    """Extract the cartesian coordinates from the range image.

    Args:
       lidar_image: [H, W, C] range image Tensor.
       lidar_image_mask: [H, W] boolean indicating which 2d coordinates in the
         lidar image are present.
       extrinsics: [4, 4] float matrix representing transformation matrix to
         world coordinates.
       inclinations: [V] beam inclinations vector.
       pixel_pose: [64, 2650, 4, 4] tensor representing per pixel pose of GBR.
       frame_pose: [4, 4] matrix representing vehicle to world transformation.

    Returns:
      [H, W, 3] range image cartesian coordinates.
    """
    height, width, channels = py_utils.GetShape(lidar_image, 3)

    conversion_dtype = tf.float32
    lidar_image = tf.cast(lidar_image, conversion_dtype)
    extrinsics = tf.cast(extrinsics, conversion_dtype)
    inclinations = tf.cast(inclinations, conversion_dtype)
    inclinations = tf.reverse(inclinations, axis=[-1])

    az_correction = py_utils.HasShape(
        tf.atan2(extrinsics[1, 0], extrinsics[0, 0]), [])
    ratios = (tf.cast(tf.range(width, 0, -1), dtype=conversion_dtype) -
              .5) / tf.cast(width, conversion_dtype)
    ratios = py_utils.HasShape(ratios, [width])

    azimuth = (ratios * 2. - 1.) * np.pi - az_correction[..., tf.newaxis]
    azimuth = py_utils.HasShape(azimuth, [width])

    lidar_image_mask = lidar_image_mask[..., tf.newaxis]
    lidar_image_mask = tf.tile(lidar_image_mask, [1, 1, channels])
    lidar_image = tf.where(lidar_image_mask, lidar_image,
                           tf.zeros_like(lidar_image))
    lidar_image_range = lidar_image[..., 0]

    azimuth = py_utils.HasShape(azimuth[tf.newaxis, ...], [1, width])
    inclinations = py_utils.HasShape(inclinations[..., tf.newaxis], [height, 1])

    cos_azimuth = tf.cos(azimuth)
    sin_azimuth = tf.sin(azimuth)
    cos_incl = tf.cos(inclinations)
    sin_incl = tf.sin(inclinations)

    x = cos_azimuth * cos_incl * lidar_image_range
    y = sin_azimuth * cos_incl * lidar_image_range
    z = sin_incl * lidar_image_range

    lidar_image_points = tf.stack([x, y, z], -1)
    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    rotation = extrinsics[0:3, 0:3]
    translation = extrinsics[0:3, 3][tf.newaxis, ...]

    # Transform the image points in cartesian coordinates to
    # the world coordinate system using the extrinsics matrix.
    #
    # We first flatten the points, apply rotation, then
    # reshape to restore the original input and then apply
    # translation.
    lidar_image_points = tf.matmul(
        tf.reshape(lidar_image_points, [-1, 3]), rotation, transpose_b=True)
    lidar_image_points = tf.reshape(lidar_image_points, [height, width, 3])
    lidar_image_points += translation

    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    # GBR uses per pixel pose.
    if pixel_pose is not None:
      pixel_pose_rotation = pixel_pose[..., 0:3, 0:3]
      pixel_pose_translation = pixel_pose[..., 0:3, 3]
      lidar_image_points = tf.einsum(
          'hwij,hwj->hwi', pixel_pose_rotation,
          lidar_image_points) + pixel_pose_translation
      if frame_pose is None:
        raise ValueError('frame_pose must be set when pixel_pose is set.')
      # To vehicle frame corresponding to the given frame_pose
      # [4, 4]
      world_to_vehicle = tf.linalg.inv(frame_pose)
      world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3]
      world_to_vehicle_translation = world_to_vehicle[0:3, 3]
      # [H, W, 3]
      lidar_image_points = tf.einsum(
          'ij,hwj->hwi', world_to_vehicle_rotation,
          lidar_image_points) + world_to_vehicle_translation[tf.newaxis,
                                                             tf.newaxis, :]

    return lidar_image_points
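For reference, the polar-to-cartesian step above for a single (azimuth, inclination, range) triple, with hypothetical values:

import tensorflow as tf

azimuth, inclination, rng = 0.3, -0.1, 20.0
x = tf.cos(azimuth) * tf.cos(inclination) * rng
y = tf.sin(azimuth) * tf.cos(inclination) * rng
z = tf.sin(inclination) * rng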
Example #12
    def _GetMask(self,
                 batch_size,
                 choose_range,
                 mask_size,
                 global_seed,
                 max_length=None,
                 masks_per_frame=0.0,
                 multiplicity=1,
                 dtype=tf.float32,
                 max_ratio=1.0):
        """Returns fixed size multi-masks starting from random positions.

    A multi-mask is a mask obtained by applying multiple masks.

    When max_length is given, this function:
      1) Samples random mask lengths less than max_length with shape
         (batch_size, multiplicity).
      2) Truncates lengths to a max of (choose_range * max_ratio),
         so that each mask is fully contained within the corresponding sequence.
      3) Randomly samples start points of shape (batch_size, multiplicity)
         within (choose_range - lengths).
      4) For each batch, multiple masks (whose number is given by the
         multiplicity) are constructed.
      5) Returns a mask of shape (batch_size, mask_size) where masks are
         obtained by composing the masks constructed in step 4).
         If masks_per_frame > 0, the number is given by
         min(masks_per_frame * choose_range, multiplicity).
         If not, all the masks are composed. The masked regions are set to zero.

    When max_length is not given, this function:
      1) Samples random mask lengths less than (choose_range * max_ratio)
         with shape (batch_size, multiplicity).
      2) Proceeds to steps 3), 4) and 5) above.

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the masked entries must lie. Tensor of
        shape (batch_size,).
      mask_size: Size of the mask. Integer number.
      global_seed: an integer seed tensor for stateless random ops.
      max_length: Maximum number of allowed consecutive masked entries. Integer
        number or None.
      masks_per_frame: Number of masks per frame. Float number. If > 0, the
        multiplicity of the mask is set to be masks_per_frame * choose_range.
      multiplicity: Maximum number of total masks. Integer number.
      dtype: Data type.
      max_ratio: Maximum portion of the entire range allowed to be masked. Float
        number.

    Returns:
      mask: a fixed size multi-mask starting from a random position with shape
      (batch_size, mask_size).
    """
        p = self.params
        # Non-empty random seed values are only used for testing or when using
        # stateless random ops. seed_1 and seed_2 are set separately to avoid
        # correlation of mask size and mask position.
        if p.use_input_dependent_random_seed:
            seed_1 = global_seed + 1
            seed_2 = global_seed + 2
        elif p.random_seed:
            seed_1 = p.random_seed + 1
            seed_2 = 2 * p.random_seed
        else:
            seed_1 = p.random_seed
            seed_2 = p.random_seed
        # Sample lengths for multiple masks.
        if max_length and max_length > 0:
            max_length = tf.broadcast_to(tf.cast(max_length, dtype),
                                         (batch_size, ))
        else:
            max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
        random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
        masked_portion = random_uniform(shape=(batch_size, multiplicity),
                                        minval=0.0,
                                        maxval=1.0,
                                        dtype=dtype,
                                        seed=seed_1)
        masked_frame_size = self.EinsumBBmBm(max_length, masked_portion)
        masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)
        # Make sure the sampled length is smaller than max_ratio * length_bound.
        # Note that sampling in this way is biased
        # (shorter sequences may be over-masked).
        choose_range = tf.expand_dims(choose_range, -1)
        choose_range = tf.tile(choose_range, [1, multiplicity])
        length_bound = tf.cast(choose_range, dtype=dtype)
        length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
        length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

        # Choose starting point.
        random_start = random_uniform(shape=(batch_size, multiplicity),
                                      maxval=1.0,
                                      seed=seed_2)
        start_with_in_valid_range = random_start * tf.cast(
            (choose_range - length + 1), dtype=dtype)
        start = tf.cast(start_with_in_valid_range, tf.int32)
        end = start + length - 1

        # Shift the start and end points by a small value.
        delta = tf.constant(0.1)
        start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
        start = tf.tile(start, [1, 1, mask_size])
        end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
        end = tf.tile(end, [1, 1, mask_size])

        # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
        diagonal = tf.expand_dims(
            tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
        diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
        pre_mask = tf.cast(tf.math.logical_and(diagonal < end,
                                               diagonal > start),
                           dtype=dtype)

        # Sum masks with appropriate multiplicity.
        if masks_per_frame > 0:
            multiplicity_weights = tf.tile(
                tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
                [batch_size, 1])
            multiplicity_tensor = masks_per_frame * tf.cast(choose_range,
                                                            dtype=dtype)
            multiplicity_weights = tf.cast(
                multiplicity_weights < multiplicity_tensor, dtype=dtype)
            pre_mask = self.EinsumBmtBmBt(pre_mask, multiplicity_weights)
        else:
            pre_mask = tf.reduce_sum(pre_mask, 1)
        mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

        if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
            mask = tf.cast(mask, p.fprop_dtype)

        return mask
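The heart of this routine is the comparison of a time-index "diagonal" against the shifted start/end points. A self-contained sketch with hypothetical values (batch_size=1, multiplicity=1, one mask covering positions 2..4):

import tensorflow as tf

mask_size = 8
start = tf.constant([[2.0]]) - 0.1  # [batch, multiplicity], shifted by delta
end = tf.constant([[4.0]]) + 0.1
t = tf.range(mask_size, dtype=tf.float32)[tf.newaxis, tf.newaxis, :]
pre_mask = tf.cast((t < end[..., tf.newaxis]) & (t > start[..., tf.newaxis]),
                   tf.float32)
mask = 1.0 - tf.cast(tf.reduce_sum(pre_mask, 1) > 0, tf.float32)
# mask -> [[1., 1., 0., 0., 0., 1., 1., 1.]]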
Example #13
    def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
        """Loop body for farthest point sampler."""
        def _GetRandomRealPoint():
            """Select the first point.

      For the first point, we want any random real (non-padded) point, so we
      create a random value per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first point.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(tf.equal(padding, 0.0), random_values,
                                     padding * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)

        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
                    (batch_size, num_points), dtype=tf.float32))
            # But only do this when we still have valid points left.
            padding_masked_distance_to_selected = tf.where(
                tf.less(curr_idx, num_valid_points),
                padding_masked_distance_to_selected, distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)

        def _GetSeededPoint():
            """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
            return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

        # Select indices for this loop iteration.
        def _Seeded():
            return tf.cond(tf.less(curr_idx, num_seeded_points),
                           _GetSeededPoint, _GetFurthestPoint)

        def _Real():
            return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                           _GetFurthestPoint)

        new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded,
                               _Real)
        sampled_idx = sampled_idx.write(curr_idx, new_selected)

        # Extract the distance to the latest point selected to update
        # distance_to_selected.
        new_selected_gather_idx = tf.stack(
            [tf.range(batch_size), new_selected], axis=1)
        if precomputed_squared_distance is not None:
            new_distance = tf.gather_nd(precomputed_squared_distance,
                                        new_selected_gather_idx)
        else:
            new_points = tf.reshape(
                tf.gather_nd(points, new_selected_gather_idx),
                [batch_size, 1, dims])
            new_distance = tf.reshape(
                SquaredDistanceMatrix(points, new_points),
                [batch_size, num_points])

        is_newly_closest = tf.less(new_distance, distance_to_selected)
        distance_to_selected = tf.minimum(distance_to_selected, new_distance)

        # Track the index to the closest selected point.
        new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
        closest_idx = tf.cond(
            tf.equal(curr_idx, 0),
            # At the first loop iteration, the init points are the closest.
            lambda: new_selected_tiled,
            # Otherwise, update with the new points based on the distances.
            lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx)
        )
        return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
Example #14
    def ComputeMetrics(self, decoder_outs, input_batch, ids_to_strings_fn):
        """Computes metrics on output from decoder.

    Args:
      decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
        decode results.
      input_batch:  A `NestedMap` of tensors representing the source, target,
        and other components of the input batch.
      ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
        shape [batch, length], lens has shape [batch], and strings has shape
        [batch].

    Returns:
      A dict of Tensors containing decoder output and metrics.
    """
        topk = self.GetTopK(decoder_outs, ids_to_strings_fn=ids_to_strings_fn)
        tgt_batch = tf.shape(topk.scores)[0]
        num_hyps_per_beam = tf.shape(topk.scores)[1]
        tgt = input_batch.tgt
        tgt_lens = tf.cast(tf.round(tf.reduce_sum(1.0 - tgt.paddings, 1)),
                           tf.int32)
        tgt_lens = py_utils.HasShape(tgt_lens, [tgt_batch])
        transcripts = ids_to_strings_fn(tgt.labels, tgt_lens - 1)

        # Filter out all isolated '<noise>' tokens.
        noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
        filtered_refs = tf.strings.regex_replace(transcripts, noise_pattern,
                                                 ' ')
        filtered_hyps = tf.strings.regex_replace(topk.decoded, noise_pattern,
                                                 ' ')
        # Compute translation quality scores for all hyps.
        filtered_refs = tf.tile(tf.reshape(filtered_refs, [-1, 1]),
                                [1, num_hyps_per_beam])
        filtered_hyps = tf.reshape(filtered_hyps, [-1])
        filtered_refs = tf.reshape(filtered_refs, [-1])
        tf.logging.info('filtered_refs=%s', filtered_refs)
        norm_wer_errors, norm_wer_words = self.ComputeNormalizedWER(
            filtered_hyps, filtered_refs, num_hyps_per_beam)

        ret_dict = {
            'target_ids': tgt.ids,
            'target_labels': tgt.labels,
            'target_weights': tgt.weights,
            'target_paddings': tgt.paddings,
            'transcripts': transcripts,
            'topk_decoded': topk.decoded,
            'topk_ids': topk.ids,
            'topk_lens': topk.lens,
            'topk_scores': topk.scores,
            'norm_wer_errors': norm_wer_errors,
            'norm_wer_words': norm_wer_words,
        }

        if not py_utils.use_tpu() and 'sample_ids' in input_batch:
            ret_dict['utt_id'] = input_batch.sample_ids

        ret_dict.update(
            self.AddAdditionalDecoderMetricsToGraph(topk, filtered_hyps,
                                                    filtered_refs, input_batch,
                                                    decoder_outs))
        return ret_dict
Example #15
  def FProp(self, theta, x, paddings=None, update=False):
    """Computes distances of the given input 'x' to all centroids.

    This implementation applies layer normalization on 'x' internally first,
    and the returned 'dists' is computed using the normalized 'x'.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar.
    """
    p = self.params
    if paddings is None:
      paddings = tf.zeros_like(x[:, :, 0, 0])
    # Shape [B, L, 1, 1]
    paddings_4d = paddings[:, :, None, None]

    if p.apply_layer_norm:
      x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

    # Since 'x' is normalized (but theta.means is not), we use the negative dot
    # product to approximate the Euclidean distance here.
    dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

    # For padded positions we update the distances to very large numbers.
    very_large_dists = tf.ones_like(dists) * tf.constant(
        0.1, dtype=dists.dtype) * dists.dtype.max
    paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
    dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

    # Shape [B, L, N, K], the same as 'dists' above.
    nearest_one_hot = tf.one_hot(
        tf.math.argmin(dists, axis=-1),
        p.num_clusters,
        dtype=py_utils.FPropDtype(p))
    # Same shape as the input 'x'.
    nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                 theta.means)
    diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
    diff = py_utils.ApplyPadding(paddings_4d, diff)
    diff = tf.math.reduce_mean(diff, axis=2)

    # The commitment loss which, when backpropagated, encourages the 'x'
    # values to commit to their chosen centroids.
    k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
    summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

    # TODO(zhouwk): investigate normalizing theta.means after each update.
    means_norm = tf.norm(theta.means)
    summary_utils.scalar('k_means/centroid_l2_norm/min',
                         tf.math.reduce_min(means_norm))
    summary_utils.scalar('k_means/centroid_l2_norm/mean',
                         tf.math.reduce_mean(means_norm))

    if not update:
      return dists, k_means_loss

    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input 'x', which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.
    #
    # Note that this approach is equivalent to backprop via
    #    loss = tf.math.reduce_mean(
    #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
    # , except that here the learning rate is independently set via 'decay'.

    # Ensure that the padded positions are not used to update the centroids.
    nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
    summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

    # Sum of the input 'x' per each closest centroid.
    sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

    if py_utils.use_tpu():
      per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
      sum_x = tf.tpu.cross_replica_sum(sum_x)

    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension will
    # be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))

    # We use an exponential moving average. TODO(zhouwk): investigate smoothing
    # this over an exponentially moving averaged per-cluster count.
    #
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
    return py_utils.with_dependencies(
        [tf.assign_add(self.vars.means, update_means_diff)],
        dists), k_means_loss
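The centroid update at the end reduces to an exponential moving average toward the per-cluster batch means; a minimal sketch (shapes and decay here are hypothetical):

import tensorflow as tf

decay = 0.999
means = tf.Variable(tf.random.normal([2, 4, 8]))  # [N, K, H]
new_means = tf.random.normal([2, 4, 8])           # batch means per cluster
means.assign_add((1.0 - decay) * (new_means - means))
# Equivalent to: means <- decay * means + (1 - decay) * new_means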
Example #16
  def testConv2DLayerStridedWithPaddingFProp(self, seq_len):
    """Check strided convs get the same values for different length dim."""
    # TODO(isaace): THIS TEST SHOWS THAT THERE IS A BUG IN THE CODE.
    with self.session(use_gpu=True):
      batch_size = 3
      expected_seq_len = 3

      params = conv_layers.Conv2DLayerWithPadding.Params()
      params.weight_norm = False
      params.filter_stride = [2, 2]
      params.name = 'conv'
      params.filter_shape = [3, 3, 1, 1]
      params.params_init = py_utils.WeightInit.Constant(1.0)
      conv_layer = params.Instantiate()

      # Set up the padding for the sequence length. (starting at 5).
      in_padding = tf.constant([
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1],
          [0, 0, 0, 1, 1],
      ], tf.float32)
      in_padding = tf.pad(
          in_padding, [[0, 0], [0, seq_len - 5]], constant_values=1.0)

      inputs = 1.0 + tf.tile(
          tf.reshape(tf.range(seq_len, dtype=tf.float32), [1, seq_len, 1, 1]),
          [batch_size, 1, 3, 1])
      inputs = py_utils.ApplyPadding(
          tf.reshape(in_padding, [batch_size, seq_len, 1, 1]), inputs)

      inputs = py_utils.Debug(inputs)

      output, out_padding = conv_layer.FPropDefaultTheta(inputs, in_padding)

      output = py_utils.Debug(output)
      out_padding = py_utils.Debug(out_padding)

      self.evaluate(tf.global_variables_initializer())
      output, out_padding = self.evaluate([output, out_padding])

      self.assertEqual((batch_size, expected_seq_len, 2, 1), output.shape)
      self.assertAllClose([
          [0, 0, 1],
          [0, 0, 1],
          [0, 1, 1],
      ], out_padding)

      # This here shows a bug in the implementation; the output should be the
      # same. Also there are bugs with the output not having the correct
      # padding.
      if seq_len == 5:
        self.assertAllClose([
            [[[6], [6]], [[18], [18]], [[18], [18]]],
            [[[6], [6]], [[18], [18]], [[8], [8]]],
            [[[6], [6]], [[10], [10]], [[0], [0]]],
        ], output)
      elif seq_len == 6:
        self.assertAllClose([
            [[[12], [12]], [[24], [24]], [[10], [10]]],
            [[[12], [12]], [[14], [14]], [[0], [0]]],
            [[[12], [12]], [[6], [6]], [[0], [0]]],
        ], output)
      else:
        raise ValueError(f'Test does not handle length {seq_len}')
Example #17
def _SingleClassDecodeWithNMS(predicted_bboxes,
                              classification_scores,
                              nms_iou_threshold,
                              score_threshold,
                              max_boxes_per_class=None):
    """Perform NMS on predicted bounding boxes / associated logits.

  Args:
    predicted_bboxes: [batch_size, num_boxes, 7] float Tensor containing
      predicted bounding box coordinates.
    classification_scores: [batch_size, num_boxes, num_classes] float Tensor
      containing predicted classification scores for each box.
    nms_iou_threshold: IoU threshold to use when determining whether two boxes
      overlap for purposes of suppression.
    score_threshold: The score threshold passed to NMS that allows NMS to
      quickly ignore irrelevant boxes.
    max_boxes_per_class: The maximum number of boxes per example to emit. If
      None, this value is set to num_boxes from the shape of predicted_bboxes.

  Returns:
    predicted_bboxes: Filtered bboxes after NMS of shape
      [batch_size, num_classes, max_boxes_per_class, 7].
    bbox_scores: A float32 Tensor with the score for each box of shape
      [batch_size, num_classes, max_boxes_per_class].
    valid_mask: A float32 Tensor with 1/0 values indicating the validity of
      each box. 1 indicates valid, and 0 invalid. Tensor of shape
      [batch_size, num_classes, max_boxes_per_class].
  """
    utils_3d = detection_3d_lib.Utils3D()
    predicted_bboxes = py_utils.HasShape(predicted_bboxes, [-1, -1, 7])
    batch_size, num_predicted_boxes, _ = py_utils.GetShape(predicted_bboxes)
    classification_scores = py_utils.HasShape(
        classification_scores, [batch_size, num_predicted_boxes, -1])
    _, _, num_classes = py_utils.GetShape(classification_scores)

    if not isinstance(nms_iou_threshold, float):
        raise ValueError('Single class NMS only supports a scalar '
                         '`nms_iou_threshold`.')
    if not isinstance(score_threshold, float):
        raise ValueError('Single class NMS only supports a scalar '
                         '`score_threshold`.')

    if max_boxes_per_class is None:
        max_boxes_per_class = num_predicted_boxes

    # TODO(jngiam): Change to be per-class bboxes, and hence, per-class NMS, and
    # per-class thresholding.
    # [batch, num_predicted_boxes]
    nms_scores = tf.reduce_max(classification_scores, axis=-1)

    # Compute the most likely label by computing the highest class score from
    # the output of the sigmoid.
    likely_labels = tf.argmax(classification_scores, axis=-1)

    # When background is the most likely class for the box, mask out the scores
    # of that box from NMS scoring so the background boxes don't dominate the
    # NMS.
    nms_scores *= tf.cast(likely_labels > 0, tf.float32)

    # Compute NMS for every sample in the batch.
    nms_indices, valid_mask = utils_3d.BatchedNMSIndices(
        predicted_bboxes,
        nms_scores,
        nms_iou_threshold=nms_iou_threshold,
        score_threshold=score_threshold,
        max_num_boxes=max_boxes_per_class)

    # Reorder the box data and logits according to NMS scoring.
    predicted_bboxes = tf.array_ops.batch_gather(predicted_bboxes, nms_indices)
    classification_scores = tf.array_ops.batch_gather(classification_scores,
                                                      nms_indices)

    # Now reformat the output of NMS to match the format of the
    # MultiClassOrientedDecodeWithNMS, which outputs a per class NMS result.
    # This takes the leading shape of
    # [batch_size, num_classes, max_boxes_per_class] for all outputs, which
    # means that, since this NMS is not class-specific, we need to tile the
    # outputs num_classes times or reorder the data to [batch, num_classes].
    predicted_bboxes = tf.tile(predicted_bboxes[:, tf.newaxis, :, :],
                               [1, num_classes, 1, 1])
    classification_scores = tf.transpose(classification_scores, (0, 2, 1))
    classification_scores = py_utils.HasShape(
        classification_scores, [batch_size, num_classes, max_boxes_per_class])
    valid_mask = tf.tile(valid_mask[:, tf.newaxis, :], [1, num_classes, 1])
    return predicted_bboxes, classification_scores, valid_mask
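A small sketch of the final reformatting step above: class-agnostic NMS results are broadcast across the class axis so they match the per-class output layout (sizes are hypothetical):

import tensorflow as tf

num_classes = 3
bboxes = tf.random.uniform([2, 5, 7])  # [batch, max_boxes_per_class, 7]
per_class = tf.tile(bboxes[:, tf.newaxis, :, :], [1, num_classes, 1, 1])
# per_class: [batch, num_classes, max_boxes_per_class, 7], identical per class.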
Example #18
    def ComputePredictions(self, theta, input_batch):
        p = self.params
        batch_size = p.input.batch_size
        self._shape_batch(input_batch)

        # Prepend the SOS token; this is not done by the Transformer layer for
        # you since it is usually done by the input pipeline in Babelfish.
        pronunciation = self._AddStartToken(input_batch.pronunciation)

        if p.use_neighbors:
            spellings = input_batch.neighbor_spellings
            pronunciations = input_batch.neighbor_pronunciations

        inp = {
            "ids": input_batch.spelling,
        }

        if (p.use_neighbors and p.also_shuffle_neighbors
                and (p.neigh_att_type == "CONCAT" or p.use_neigh_id_emb)):
            # If we use neighbor IDs, shuffle the neighbors to stop the model
            # from overfitting to the ordering of the neighbors.
            # Concat, then shuffle and split, so that the spelling and
            # pronunciation are shuffled the same way and the IDs stay aligned.
            neighbor_info = tf.concat([spellings, pronunciations], axis=-1)
            # Transpose the max_neighbors dimension to the front and shuffle.
            neighbor_info = tf.transpose(
                tf.random.shuffle(tf.transpose(neighbor_info, (1, 2, 0))),
                (2, 0, 1))
            spellings, pronunciations = (
                neighbor_info[:, :, :p.max_spelling_len],
                neighbor_info[:, :, p.max_spelling_len:])

        if p.use_neighbors and p.neigh_att_type == "CONCAT":
            # Interleave and flatten the neighbors' info.
            # ->(batch_size, max_neighbors, max_spelling_len + max_pronunciation_len)
            neigh_info = tf.concat([spellings, pronunciations], axis=2)
            # ->(batch_size, max_neighbors*(max_spelling_len + max_pronunciation_len))
            neigh_info = tf.reshape(neigh_info, (batch_size, -1))

            inp["ids"] = tf.concat([inp["ids"], neigh_info], axis=1)

            # If we are just concatenating everything, then the main encoder
            # needs neighbor IDs.
            neigh_ids = tf.range(p.max_neighbors)[:, tf.newaxis]
            neigh_ids = tf.tile(
                neigh_ids,
                (batch_size, p.max_spelling_len + p.max_pronunciation_len))
            neigh_ids = tf.reshape(neigh_ids, (batch_size, -1))
            # Add the ids for the main input
            main_ids = tf.tile([[p.max_neighbors]],
                               (batch_size, p.max_spelling_len))
            inp["task_ids"] = tf.concat([main_ids, neigh_ids], axis=1)

        inp["paddings"] = self._GetPaddings(inp["ids"], dtype=tf.int32)
        enc_out = self.encoder.FProp(theta.encoder, py_utils.NestedMap(inp))

        # Auxiliary inputs that the decoder can attend to; currently these can
        # be neighbor summaries.
        aux_inputs = []
        aux_paddings = []

        if p.use_neighbors and p.neigh_att_type != "CONCAT":
            neigh_enc, padding = self._GetAxiliaryNeighInputs(
                spellings, pronunciations, enc_out, theta, batch_size)

            aux_inputs.extend(neigh_enc)
            aux_paddings.extend(padding)

        if aux_inputs:
            aux_inputs = tf.concat(aux_inputs, axis=0)
            aux_paddings = tf.concat(aux_paddings, axis=0)

            if p.aux_dropout_prob and not self.do_eval:
                aux_inputs = tf.nn.dropout(
                    aux_inputs,
                    p.aux_dropout_prob,
                    noise_shape=(aux_inputs.get_shape().as_list()[0],
                                 batch_size, 1))

            enc_out.encoded = tf.concat([enc_out.encoded, aux_inputs], axis=0)
            enc_out.padding = tf.concat([enc_out.padding, aux_paddings],
                                        axis=0)

        enc_out.embedded_inputs = None  # to verify this is not used
        predictions = self.decoder.ComputePredictions(
            theta.decoder, enc_out,
            py_utils.NestedMap({
                "ids":
                pronunciation,
                "paddings":
                self._GetPaddings(pronunciation),
                "weights":
                tf.ones_like(input_batch.pronunciation, dtype=tf.float32),
            }))

        beam_out = self.decoder.BeamSearchDecode(enc_out, p.beam_size)
        top_ids = tf.reshape(beam_out.topk_ids,
                             [batch_size, -1, p.max_pronunciation_len])
        # Just take the top beam decodings
        top_ids = top_ids[:, 0, :]

        if p.is_inference:
            self.BuildInferenceInfo(top_ids, input_batch.pronunciation,
                                    enc_out)
            self.per_example_tensors["beam_scores"] = beam_out.topk_scores

        self.per_example_tensors["hyp"] = top_ids
        self.per_example_tensors["cognate_id"] = input_batch.cognate_id
        self.per_example_tensors["inp"] = input_batch.spelling
        self.per_example_tensors["ref"] = input_batch.pronunciation
        if p.use_neighbors:  # Note that we cannot return None!
            self.per_example_tensors[
                "neighbor_spellings"] = input_batch.neighbor_spellings
            self.per_example_tensors[
                "neighbor_pronunciations"] = input_batch.neighbor_pronunciations
        self.prediction_values = predictions
        predictions.batch = input_batch

        return predictions
Example #19
    def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                       cached_var):
        """Computes mean and variance over the valid data points in inputs.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
      cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
      cached_count: same shape as cached_sum.
      cached_var: same shape as cached_sum.

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1]
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
    """
        tf.logging.vlog(1, 'inputs: %r', inputs)
        tf.logging.vlog(1, 'paddings: %r', paddings)
        tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
        tf.logging.vlog(1, 'cached_count: %r', cached_count)

        mask = tf.cast(1.0 - paddings, inputs.dtype)
        inputs *= tf.cast(mask, inputs.dtype)

        input_rank = py_utils.GetRank(inputs)
        assert input_rank is not None, (f'inputs rank must be static for '
                                        f'{repr(inputs)}')
        reduce_over_dims = list(range(input_rank))
        # Skip B, T, and N. Reduce {F,G} or just G.
        reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
        tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

        # [B, T, 1, N, 1] or [B, T, N, 1]
        sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
        sum_v = tf.math.cumsum(sum_v, axis=1)
        sum_v += cached_sum

        # [B, T, 1, 1, 1] or [B, T, 1, 1]
        count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
        count_v = tf.math.cumsum(count_v, axis=1)
        input_shape = py_utils.GetShape(inputs)
        if input_rank == 4:
            # F * G
            multiplier = input_shape[-1] * input_shape[-3]
        else:
            # G
            multiplier = input_shape[-1]
        count_v *= multiplier
        count_v += cached_count
        count_v = tf.maximum(count_v, 1.0)

        tf.logging.vlog(1, 'sum_v: %r', sum_v)
        tf.logging.vlog(1, 'count_v: %r', count_v)

        mean = sum_v / count_v
        if py_utils.FLAGS.tflite_compatible:
            # TfLite doesn't support broadcasting with 5D tensors.
            inputs_shape = py_utils.GetShape(inputs)
            if len(inputs_shape) == 4:
                tiled_mean = tf.tile(mean, [1, 1, 1, inputs_shape[3]])
            else:
                tiled_mean = tf.tile(
                    mean, [1, 1, inputs_shape[2], 1, inputs_shape[4]])
            sum_vv = tf.reduce_sum(tf.math.square(inputs - tiled_mean) * mask,
                                   reduce_over_dims,
                                   keepdims=True)
        else:
            sum_vv = tf.reduce_sum((inputs - mean)**2 * mask,
                                   reduce_over_dims,
                                   keepdims=True)
        sum_vv = tf.math.cumsum(sum_vv, axis=1)
        sum_vv += cached_var

        cached_sum = sum_v[:, -1:]
        cached_count = count_v[:, -1:]
        cached_var = sum_vv[:, -1:]

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
        ], sum_vv / count_v)
        return mean, variance, cached_sum, cached_count, cached_var
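A compact sketch of the streaming bookkeeping above: cumulative sums over time (plus the cache carried between calls, omitted here) give a running mean per step, with padded frames excluded from both the sum and the count:

import tensorflow as tf

x = tf.constant([[1., 2., 3., 4.]])     # [B=1, T=4]
mask = tf.constant([[1., 1., 1., 0.]])  # last frame is padding
sum_v = tf.math.cumsum(x * mask, axis=1)
count_v = tf.maximum(tf.math.cumsum(mask, axis=1), 1.0)
mean = sum_v / count_v                  # [[1., 1.5, 2., 2.]]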
Example #20
def ExtractBlockContextV2(x,
                          block_size,
                          left_context,
                          right_context,
                          padding_val=0.0,
                          paddings=None):
    """Extracts temporal context for every block (without restrictions).

  This is a generalized implementation of ExtractBlockContext, where block_size,
  left_context, and right_context are 3 free parameters and we don't have
  constraints (other than l>=1, r>=0, block_size>0).

  Args:
    x: a tensor of [batch, time, dim].
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size. Note that the actual left context is
      `left_context - 1` (this is to be compatible with ExtractBlockContext
      implementation).
    right_context: int. Right context size.
    padding_val: float. value on the padded frames.
    paddings: optional. If specified, it must be a tensor of [batch, time], and
      we will return a padding tensor indicating padding info for the returned
      tensor.

  Returns:
    (x_patches, x_paddings) where

    - x_patches: A tensor of
      [batch, num_blocks, context_size + block_size, dim] with necessary
      paddings, where context_size = (left_context - 1) + right_context,
      and output[:, i, ...] are
      x[:, start-left_context+1:end+right_context, ...], where
      start = i * block_size, end = (i + 1) * block_size.
    - x_paddings: None if paddings = None; else a
      [batch, num_blocks, context_size + block_size] tensor, indicating the
      padding info for the corresponding position in x_patches.

  Let's define some variables here:

  B: batch size
  T: input tensor length in time axis
  D: input tensor dimension in the last axis
  W: block size
  U: ceil(T/W)
  L: left context size
  R: right context size
  C: L-1+W+R, full block length

  Given a [B, T, D] tensor, the return is a [B, U, C, D] tensor
  where ret[b, u, :] is a 2D tensor of shape (L - 1 + W + R, D): the u-th
  block of the input tensor with (L - 1) left context frames and R right
  context frames.

  Implementation note:

  We use the following procedure to get the return tensor

  - first do padding at the beginning and at the end:
    [B, T, D] -> [B, (L-1) + U*W + (L-1) + R, D]
  - use gather to gather the blocks:
    [B, (L-1) + U*W + (L-1) + R, D] -> [B, U, C, D]

  TODO(yqw): after verifying correctness and benchmarks, consider replacing
  the v1 implementation.
  """
    # 0. basic shapes
    b, t, d = py_utils.GetShape(x, 3)
    w = block_size
    u = (t + w - 1) // w  # equivalent to math.ceil(t/w)
    l = left_context
    r = right_context
    c = l - 1 + r + w

    # the only constraints are block_size > 0, l >= 1, r >= 0
    if w <= 0:
        raise ValueError(f'block size ({w}) must be greater than 0')
    if l < 1:
        raise ValueError(f'Left context ({left_context}) must be >= 1.')
    if r < 0:
        raise ValueError(f'Right context ({right_context}) must be >= 0')
    if paddings is not None:
        paddings = py_utils.HasShape(paddings, [b, t])

    # 1. do front and rear padding
    left_pad = l - 1
    # we need to make sure all u * w elements have long enough context
    right_pad = (u * w - t + l - 1 + r)
    x_padded = _DoPadding(x,
                          b,
                          left_pad,
                          right_pad,
                          d,
                          padding_val=padding_val)
    if paddings is not None:
        paddings = _DoPadding(paddings,
                              b,
                              left_pad,
                              right_pad,
                              d=None,
                              padding_val=1.0)

    # 2. generate gather indices
    # gather_indices is a [u, c] matrix like
    #  [ 0, .........,             c-1]
    #  [ w, .........,       w + (c-1)]
    #  [2w, ..........,     2w + (c-1)]
    #  [(u-1)*w, ...., (u-1)*w + (c-1)]
    gather_indices = (tf.tile(tf.expand_dims(tf.range(0, c), axis=0), (u, 1)) +
                      tf.tile(tf.expand_dims(tf.range(0, u * w, w), axis=1),
                              (1, c)))

    # 3. generate x_patches, shape [b, u, c, d]
    x_patches = tf.gather(x_padded, gather_indices, axis=1)

    if paddings is not None:
        # paddings is now a [b, u, c] tensor
        paddings = tf.gather(paddings, gather_indices, axis=1)

    return x_patches, paddings
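# --- Hypothetical usage sketch for ExtractBlockContextV2 (values assumed, and
# it presumes the py_utils helpers used above are importable): blocks of W=2
# frames with left_context=2 (one effective left frame) and right_context=1.
import tensorflow as tf

x = tf.reshape(tf.range(6, dtype=tf.float32), [1, 6, 1])  # [B=1, T=6, D=1]
patches, patch_paddings = ExtractBlockContextV2(
    x, block_size=2, left_context=2, right_context=1,
    paddings=tf.zeros([1, 6]))
# patches: [1, 3, 4, 1], since U = ceil(6 / 2) = 3 and C = (2-1) + 2 + 1 = 4.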
Ejemplo n.º 21
 def TileForBeamAndFlatten(tensor):
   tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
   tensor = tf.tile(
       tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
   tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
   return tf.reshape(tensor, [tgt_batch])
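# --- Small numeric illustration (hypothetical values) of the tiling above: a
# per-beam tensor is repeated num_hyps_per_beam times in hyp-major order.
import tensorflow as tf

t = tf.constant([7, 9])             # [src_batch=2]
t = tf.reshape(t, [1, -1])          # [1, 2]
t = tf.tile(t, [3, 1])              # [3, 2], with num_hyps_per_beam=3
print(tf.reshape(t, [-1]).numpy())  # [7 9 7 9 7 9]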
Ejemplo n.º 22
    def RelPositionBias(self, content, abs_pos_emb, skip_term_b=False):
        """Compute relative position bias.

    This is a subroutine used by variants of self-attentions with relative
    positional embedding.

    output[b][n][i][j] = dot(content[b][i][n], abs_pos_emb[i - j + T - 1][n])

    Padding should be masked by the caller of this function.

    B: batch size
    T: sequence length
    N: num of attention heads.
    H: per-head attention dimension.

    Args:
      content: a tensor of shape [N, H] if skip_term_b, else [B, T, N, H].
      abs_pos_emb: a tensor of shape [2T - 1, N, H], the absolute positional
        embedding; abs_pos_emb[i] is the embedding of relative distance
        i - (T - 1).
      skip_term_b: whether to skip term_b in the section 3.3 equation.

    Returns:
      The attention logits tensor. [N, T, T] if skip_term_b else [B, N, T, T].
    """
        if not skip_term_b:
            b, t, n, h = py_utils.GetShape(content)
            l = 2 * t - 1
            abs_pos_emb = py_utils.HasShape(abs_pos_emb, [l, n, h])
        else:
            n, h = py_utils.GetShape(content)
            l = py_utils.GetShape(abs_pos_emb)[0]
            t = (l + 1) // 2

        if not skip_term_b:
            # [B, N, T, L=2T-1]
            content, abs_pos_emb = self.ToAqtActActInputs(content, abs_pos_emb)
            term_bd = tf.einsum('BTNH,LNH->BNTL', content, abs_pos_emb)
            term_bd = self.FromAqtActActMatmul(term_bd)

            term_bd = tf.reshape(term_bd, [b, n, t * l], name='flatten')
            # [B, N, T * (L + 1)].
            term_bd = tf.pad(term_bd, ((0, 0), (0, 0), (0, t)))
            # [B, N, T, L + 1].
            term_bd = tf.reshape(term_bd, [b, n, t, l + 1], name='restore')
            return term_bd[:, :, :, t - 1::-1]
        else:
            # [N, L=2T-1]
            content, abs_pos_emb = self.ToAqtActActInputs(content, abs_pos_emb)
            term_d = tf.einsum('NH,LNH->NL', content, abs_pos_emb)
            term_d = self.FromAqtActActMatmul(term_d)

            # [N, T, L]
            term_d = tf.tile(tf.expand_dims(term_d, axis=1), [1, t, 1],
                             name='tile')
            term_d = tf.reshape(term_d, [n, t * l])
            # [N, T * (L + 1)].
            term_d = tf.pad(term_d, ((0, 0), (0, t)))
            # [N, T, L + 1].
            term_d = tf.reshape(term_d, [n, t, l + 1], name='restore')
            return term_d[:, :, t - 1::-1]
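# --- Small numeric check (not from the original) of the pad-reshape-slice
# shift used above, with T=2 and L=2T-1=3, on a single [T, L] matrix.
import tensorflow as tf

t, l = 2, 3
term = tf.constant([[0., 1., 2.],      # term[i, k] holds the logit for
                    [10., 11., 12.]])  # relative index k = i - j + (T-1)
flat = tf.pad(tf.reshape(term, [t * l]), [[0, t]])  # append T zeros
shifted = tf.reshape(flat, [t, l + 1])[:, t - 1::-1]
# shifted[i, j] == term[i, i - j + t - 1] -> [[1., 0.], [12., 11.]]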
Ejemplo n.º 23
def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None):
  """Computes attention according to a sparsity pattern.

  We use the following capital letters to denote shape parameters:
    B = batch size
    S = length of the source sequence
    T = length of the target sequence
    N = number of attention heads
    H = dimensions of each attention head
    K = number of clusters
    W = attention window (W <= S)

  The 'sparsity_indices' is a tensor of integral type where the last dimension
  contains W indices (W is the attention window) into the source sequence 'k'
  that each query position is allowed to attend to.

  For example, if sparsity_indices[batch_idx, target time step, head_idx] =
  [1, 7, 8], it means that this query token attends to the values with indices
  1, 7, and 8, and the attention window here is 3.

  The valid values in 'sparsity_indices' are in [-1, S-1]. Note that the value
  -1 is reserved to mean paddings, distinct from the valid index (S-1).

  For example, if W=S and 'sparsity_indices' contains range(S) on the last
  dimension, this degenerates to the original full attention.

  We require that 'sparsity_indices' does not contain duplicates (except for -1
  to indicate paddings), but we do not require 'sparsity_indices' to be sorted.

  Note that this implementation is flexible and generic but is not optimized for
  time or space complexity. Please consider grouping queries that attend to the
  same subset of values first for efficiency.

  Args:
    q: (projected) queries, [B, T, N, H];
    k: (projected) keys, [B, S, N, H];
    v: (projected) values, [B, S, N, H];
    sparsity_indices: [B, T, N, W], where W is the attention window;
    paddings: paddings for keys, [B, S] if not None.

  Returns:
    output: the encoded output, [B, T, N, H].
    atten_probs: the attention weights, [B, T, N, S].
  """
  q = tf.convert_to_tensor(q)
  k = tf.convert_to_tensor(k)
  v = tf.convert_to_tensor(v)
  sparsity_indices = tf.convert_to_tensor(sparsity_indices)

  k = py_utils.HasRank(k, 4)
  _, source_length, _, dim_per_head = py_utils.GetShape(k, 4)
  sparsity_indices = py_utils.HasRank(sparsity_indices, 4)
  batch_size, target_length, num_heads, attention_window = py_utils.GetShape(
      sparsity_indices, 4)
  sparsity_indices = py_utils.with_dependencies([
      py_utils.assert_less_equal(
          attention_window, source_length,
          'The provided sparsity_indices has attention window '
          '> source length. This is likely an error.')
  ], sparsity_indices)

  # To prepare for gathering the relevant vectors from 'k', we prepare
  # gather_idx of shape [B, T, N, W, 3] where the last dimension corresponds to
  # slices in 'k' indexed by (batch index, source time step, head index),
  # where the source length index comes from the original W dimension in
  # 'sparsity_indices'.
  seq_idx = tf.expand_dims(sparsity_indices, axis=-1)
  # Overwrite the paddings -1 with valid gather indices (zeros). We will
  # fix the logits with -inf in these positions later.
  seq_idx = tf.where(seq_idx < 0, tf.zeros_like(seq_idx), seq_idx)
  batch_idx = tf.reshape(
      tf.range(0, batch_size, dtype=sparsity_indices.dtype),
      [batch_size, 1, 1, 1, 1])
  batch_idx = tf.tile(batch_idx,
                      [1, target_length, num_heads, attention_window, 1])
  head_idx = tf.reshape(
      tf.range(0, num_heads, dtype=sparsity_indices.dtype),
      [1, 1, num_heads, 1, 1])
  head_idx = tf.tile(head_idx,
                     [batch_size, target_length, 1, attention_window, 1])
  # [B, T, N, W, 3], where last dimension is (batch index, source length index,
  # head index).
  gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1)

  # Both the gathered k and v have shape [B, T, N, W, H]
  k = tf.gather_nd(k, gather_idx)
  v = tf.gather_nd(v, gather_idx)

  if paddings is None:
    paddings = tf.zeros([batch_size, source_length])
  paddings = tf.convert_to_tensor(paddings)
  paddings = tf.expand_dims(paddings, axis=-1)
  # [B, S, N]
  paddings = tf.tile(paddings, [1, 1, num_heads])
  # [B, T, N, W]
  paddings = tf.gather_nd(paddings, gather_idx)

  logits = tf.einsum('BTNH, BTNWH -> BTNW', q, k)
  logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))

  very_negative_logits = (
      tf.ones_like(logits) * logits.dtype.max *
      tf.constant(-0.7, dtype=logits.dtype))
  padded_logits = tf.where(
      tf.math.logical_or(sparsity_indices < 0, paddings > 0.0),
      very_negative_logits, logits)

  # [B, T, N, W]
  atten_probs = tf.nn.softmax(padded_logits, name='attention_weights')
  atten_probs = tf.where(sparsity_indices < 0, tf.zeros_like(logits),
                         atten_probs)
  output = tf.einsum('BTNW, BTNWH -> BTNH', atten_probs, v)

  # Scatter 'atten_probs' back into the original source length.
  # [B, T, N, W, 1]
  batch_idx = tf.tile(
      tf.range(batch_size)[:, None, None, None, None],
      [1, target_length, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  target_seq_idx = tf.tile(
      tf.range(target_length)[None, :, None, None, None],
      [batch_size, 1, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  head_idx = tf.tile(
      tf.range(num_heads)[None, None, :, None, None],
      [batch_size, target_length, 1, attention_window, 1])
  # seq_idx: [B, T, N, W, 1]
  # [B, T, N, W, 4]
  scatter_idx = tf.concat([batch_idx, target_seq_idx, head_idx, seq_idx], -1)
  # [B, T, N, S]
  scattered_probs = tf.scatter_nd(
      scatter_idx, atten_probs,
      [batch_size, target_length, num_heads, source_length])
  return output, scattered_probs
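# --- Hypothetical usage sketch for ComputeSparseAttention: a local window of
# W=2 where each target position attends to itself and its left neighbor
# (-1 marks padding at the first position). Here B=1, T=S=3, N=1, H=4.
import tensorflow as tf

b, t, n, h = 1, 3, 1, 4
q = tf.random.normal([b, t, n, h])
k = tf.random.normal([b, t, n, h])
v = tf.random.normal([b, t, n, h])
sparsity_indices = tf.constant([[[[-1, 0]], [[0, 1]], [[1, 2]]]])  # [1,3,1,2]
output, atten_probs = ComputeSparseAttention(q, k, v, sparsity_indices)
# output: [1, 3, 1, 4]; atten_probs: [1, 3, 1, 3], scattered back over S.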
Ejemplo n.º 24
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in `x`, which represents the groundtruth sequence (i.e., the
    English sequence). We return a sampled rollin (observed) canvas (i.e., a
    random subset of the English sequence), as well as the targets (indices)
    for an insertion-based model (i.e., the targets given the random observed
    subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based on the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that `c_dim` <= `time_dim`, but they need
          not be equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If invalid params.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`.
            # We can accomplish this by adding a large constant (+1e9) to the
            # logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        c *= tf.cast(1 - c_paddings, tf.int32)

        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed tokens to <eos>. Note some of these need a
        # zero weight (or else there would be both an <eos> and a valid token
        # in the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
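# --- Standalone sketch (plain TensorFlow, not from the original) of the
# Gumbel-max trick used above: taking top-k over logits plus Gumbel noise
# samples k distinct items, sequentially without replacement, with
# probabilities given by the softmax of the logits.
import tensorflow as tf

logits = tf.math.log(tf.constant([0.1, 0.4, 0.2, 0.3]))
gumbel = -tf.math.log(-tf.math.log(tf.random.uniform([4])))
_, sampled = tf.nn.top_k(logits + gumbel, k=2)  # two distinct indices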
Ejemplo n.º 25
0
def NeighborhoodIndices(points,
                        query_points,
                        k,
                        points_padding=None,
                        max_distance=None,
                        sample_neighbors_uniformly=False):
    """Get indices to k-neighbors of query_points in points.

  Padding is returned alongside indices. Non-padded points are guaranteed to
  be unique (non-repeated) points from the original non-padded points.

  Padded points arise either from a lack of points (k exceeds the number of
  original non-padded points) or from points being too far away (exceeding
  max_distance).

  Note: Padded point indices may refer to padded points from the original, or
  may be duplicates of the closest point.

  TODO(weihan,jngiam): PointCNN implementation makes an assumption that padded
  points are repeated points from the original points. This behavior is
  maintained here, but we should update PointCNN to respect indices paddings.

  Args:
    points: tensor of shape [N, P1, dims].
    query_points: tensor of shape [N, P2, dims].
    k: Integer.
    points_padding: optional tensor of shape [N, P1] containing True/1.0 iff
      the point is a padded point. If None, then all points are considered
      real points.
    max_distance: float representing the maximum distance that each neighbor can
      be. If there are no points within the distance, then the closest point is
      returned (regardless of distance). If this is set to None, then no
      filtering by distance is performed.
    sample_neighbors_uniformly: boolean specifying whether to sample neighbors
      uniformly if they are within max distance.

  Returns:
    A pair of tensors:

    - indices: tensor of shape [N, P2, k].
    - padding: tensor of shape [N, P2, k] where 1 represents a padded point, and
      0 represents an unpadded (real) point.

  """
    n, p1 = py_utils.GetShape(points, 2)
    query_points = py_utils.HasShape(query_points, [n, -1, -1])
    _, p2 = py_utils.GetShape(query_points, 2)

    # Compute pair-wise squared distances.
    # Note that dist_mat contains the squared distance (without sqrt). Thus, when
    # using max_distance, we will need to square max_distance to make sure it's
    # in the same units.
    dist_mat = SquaredDistanceMatrix(query_points, points)
    dist_mat = py_utils.HasShape(dist_mat, [n, p2, p1])

    # Add a large scalar to the distances for padded points.
    # dist_mat[i, j, k] will be:
    #   if k < valid_num[i]: distance between points[i, k] and query_points[i, j]
    #   otherwise:           a large scalar added to dist_mat[i, j, k]
    if points_padding is not None:
        points_padding = tf.cast(tf.expand_dims(points_padding, 1), tf.float32)
        points_padding = py_utils.HasShape(points_padding, [n, 1, p1])
        large_scalar = tf.reduce_max(dist_mat) + 1
        dist_mat += points_padding * large_scalar

    # To perform sampling neighbors uniformly efficiently, we set all neighbors
    # that are within the distance threshold to have distances be drawn uniformly
    # at random. Using top_k with this enables selecting a random set quickly
    # without replacement.
    if sample_neighbors_uniformly:
        if max_distance is not None:
            mask_by_distance = tf.less_equal(dist_mat, max_distance**2)
            dist_mat = tf.where(
                mask_by_distance,
                tf.square(max_distance) *
                tf.random.uniform(tf.shape(dist_mat)), dist_mat)
        else:
            raise ValueError(
                'Uniform sampling requires specifying max_distance.')

    top_k_dist, indices = tf.nn.top_k(-dist_mat, k=k,
                                      sorted=True)  # N x P2 x K

    # Set padding using top_k_dist; padded points will have distance exceeding
    # the large_scalar.
    if points_padding is not None:
        paddings = tf.greater_equal(-top_k_dist, large_scalar)
    else:
        paddings = tf.zeros_like(top_k_dist, dtype=tf.bool)

    # Filter by max_distances by setting all indices that exceed the max_distance
    # to the closest point.
    if max_distance is not None:
        # Mask is true for points that are further than max_distance.
        mask_by_distance = tf.greater(-top_k_dist, tf.square(max_distance))
        closest_idx = tf.tile(indices[:, :, :1], [1, 1, k])
        indices = tf.where(mask_by_distance, closest_idx, indices)
        paddings |= mask_by_distance

    indices = tf.reshape(indices, [n, p2, k])
    paddings = tf.cast(paddings, tf.float32)

    return indices, paddings
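# --- Hypothetical usage sketch for NeighborhoodIndices (it assumes the
# helpers above, e.g. SquaredDistanceMatrix, are importable): 2 neighbors for
# each of 2 query points among 4 points in 2-D.
import tensorflow as tf

points = tf.constant([[[0., 0.], [1., 0.], [0., 1.], [5., 5.]]])  # [1, 4, 2]
queries = tf.constant([[[0., 0.], [4., 4.]]])                     # [1, 2, 2]
indices, paddings = NeighborhoodIndices(points, queries, k=2)
# indices: [1, 2, 2]; paddings: all zeros, since enough real points exist.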
Ejemplo n.º 26
    def testCausalConv2DLayerStridedWithPaddingFProp(self, seq_len):
        """Check strided convs get the same values for different length dim."""
        # TODO(isaace): THIS TEST SHOWS THAT THERE IS A BUG WITH PADDING
        with self.session(use_gpu=True) as sess:
            batch_size = 5
            expected_seq_len = 3

            params = conv_layers.CausalConv2DLayerWithPadding.Params()
            params.weight_norm = False
            params.filter_stride = [2, 2]
            params.name = 'conv'
            params.filter_shape = [3, 1, 1, 1]
            params.params_init = py_utils.WeightInit.Constant(1.0)
            conv_layer = params.Instantiate()

            # Set up the padding for the sequence length (starting at 5).
            in_padding = tf.constant([
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 1, 1],
                [0, 1, 1, 1, 1],
            ], tf.float32)
            in_padding = tf.pad(in_padding, [[0, 0], [0, seq_len - 5]],
                                constant_values=1.0)

            inputs = 1.0 + tf.tile(
                tf.reshape(tf.range(seq_len, dtype=tf.float32),
                           [1, seq_len, 1, 1]), [batch_size, 1, 3, 1])
            inputs = py_utils.ApplyPadding(
                tf.reshape(in_padding, [batch_size, seq_len, 1, 1]), inputs)

            inputs = py_utils.Debug(inputs)

            output, out_padding = conv_layer.FPropDefaultTheta(
                inputs, in_padding)

            output = py_utils.Debug(output)
            out_padding = py_utils.Debug(out_padding)

            self.evaluate(tf.global_variables_initializer())
            output, out_padding = sess.run([output, out_padding])

            self.assertEqual((batch_size, expected_seq_len, 2, 1),
                             output.shape)
            self.assertAllClose([
                [0, 0, 1],
                [0, 0, 1],
                [0, 1, 1],
                [0, 1, 1],
                [1, 1, 1],
            ], out_padding)

            # NOTE: There is a bug in the output not being padded correctly.
            self.assertAllClose([
                [[[1], [1]], [[6], [6]], [[12], [12]]],
                [[[1], [1]], [[6], [6]], [[7], [7]]],
                [[[1], [1]], [[6], [6]], [[3], [3]]],
                [[[1], [1]], [[3], [3]], [[0], [0]]],
                [[[1], [1]], [[1], [1]], [[0], [0]]],
            ], output)
Ejemplo n.º 27
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ],
                                               tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
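# --- Minimal sketch (hypothetical values) of the per-beam top-k selection in
# MergeBeamSearchOutputs: pick the best 2 of 4 merged hyps for each of 2 beams.
import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.4, 0.3],
                      [0.8, 0.2, 0.5, 0.7]])  # [source_batch, total_hyps]
_, top_indices = tf.nn.top_k(scores, 2)
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, 2])
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)
# tf.gather_nd(scores, gather_indices) -> [[0.9, 0.4], [0.8, 0.7]]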
Ejemplo n.º 28
    def ComputePredictions(self, theta, input_batch):
        """Computes predictions for `input_batch`.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      input_batch: A `.NestedMap` expected to contain cell_center_xyz,
        cell_points_xyz, cell_feature, anchor_bboxes,
        anchor_localization_residuals, assigned_gt_labels, and
        assigned_cls_mask. See class doc string for details.

    Returns:
      A `.NestedMap` object containing residuals and classification_logits.
    """
        p = self.params
        input_batch.Transform(lambda x:
                              (x.shape, x.shape.num_elements())).VLog(
                                  1, 'input_batch shapes: ')

        cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
        batch_size, num_centers, num_points_per_cell = py_utils.GetShape(
            cell_feature, 3)

        cell_points_xyz = py_utils.HasShape(
            input_batch.cell_points_xyz,
            [batch_size, num_centers, num_points_per_cell, 3])
        cell_center_xyz = py_utils.HasShape(input_batch.cell_center_xyz,
                                            [batch_size, num_centers, 3])

        cell_points_padding = py_utils.HasShape(
            input_batch.cell_points_padding,
            [batch_size, num_centers, num_points_per_cell])

        # TODO(jngiam): Make concat_feature computation a layer or configurable.
        cell_center_xyz = tf.reshape(cell_center_xyz,
                                     [batch_size, num_centers, 1, 3])
        centered_cell_points_xyz = cell_points_xyz - cell_center_xyz
        concat_feature = tf.concat([
            tf.tile(cell_center_xyz, [1, 1, num_points_per_cell, 1]),
            centered_cell_points_xyz, cell_feature
        ],
                                   axis=-1)  # pyformat: disable

        # Featurize point clouds at each center.
        point_input = py_utils.NestedMap({
            'points': centered_cell_points_xyz,
            'features': concat_feature,
            'padding': cell_points_padding,
        })
        featurized_cell = self.cell_featurizer.FProp(theta.cell_featurizer,
                                                     point_input)
        featurized_cell = py_utils.HasShape(featurized_cell,
                                            [batch_size, num_centers, -1])

        # Predict localization residuals.
        predicted_residuals = self.localization_regressor.FProp(
            theta.localization_regressor, featurized_cell)
        predicted_residuals = tf.reshape(
            predicted_residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

        if p.squash_rotation_predictions:
            predicted_rotations = predicted_residuals[..., 6:]
            predicted_rotations = np.pi * tf.tanh(predicted_rotations)
            predicted_residuals = tf.concat(
                [predicted_residuals[..., :6], predicted_rotations], axis=-1)

        # Predict object classification at each bbox.
        predicted_classification_logits = self.classifier.FProp(
            theta.classifier, featurized_cell)
        predicted_classification_logits = tf.reshape(
            predicted_classification_logits, [
                batch_size, num_centers, p.num_anchor_bboxes_per_center,
                p.num_classes
            ])

        return py_utils.NestedMap({
            'residuals':
            predicted_residuals,
            'classification_logits':
            predicted_classification_logits,
        })
Ejemplo n.º 29
    def BuildDataSource(self, data_source_from_file_pattern_fn):
        """Read and return input batch from a p.file_pattern list.

    `p.file_patterns` is a list of file patterns, and `p.weights` contains a
    weight for each file pattern. If provided, `p.bprop_variable_filters`
    includes a bprop_variable_filter for each file pattern.

    Args:
      data_source_from_file_pattern_fn: a function that takes file_pattern as an
        argument and returns an input batch.

    Returns:
      A NestedMap containing:

      - data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor.
      - source_selected: a tensor of size [batch_size, number of data sources].
      - selected_bprop: a tensor of size [number of data sources].
      - bprop_variable_filters: a list of bprop_variable_filters, one per data
        source.

    Raises:
      ValueError: If p.file_patterns and p.weights differ in length, or if
        p.file_patterns contains non-string entries.
    """
        p = self.params

        def _MakeDataSourceFromFilePatternFunc(
                data_source_from_file_pattern_fn, file_pattern):
            # It's important to invoke data_source_from_file_pattern_fn inside
            # the lambda to make sure that a record is drawn from the data
            # source only if it will be used. Weights are handled by
            # MixByWeight, not by data_source_from_file_pattern_fn.
            return lambda: data_source_from_file_pattern_fn(file_pattern)

        if len(p.weights) != len(p.file_patterns):
            raise ValueError(
                'Expected p.file_patterns and p.weights to be the same length. '
                'Found %d file_patterns, and %d weights' %
                (len(p.file_patterns), len(p.weights)))
        if not all(isinstance(x, str) for x in p.file_patterns):
            raise ValueError(
                'Expected all elements of p.file_patterns to be strings')

        # TODO(rosenberg) replace this with functools.partial
        inputs = [
            _MakeDataSourceFromFilePatternFunc(
                data_source_from_file_pattern_fn, file_pattern)
            for file_pattern in p.file_patterns
        ]
        weights = p.weights
        if not p.bprop_variable_filters:
            bprop_variable_filters = [''] * len(inputs)
        else:
            bprop_variable_filters = p.bprop_variable_filters

        data_source, selected_bprop = py_utils.MixByWeight(inputs,
                                                           weights,
                                                           seed=p.random_seed)
        # TODO(neerajgaur): Remove _bprop_onehot and change code that uses it to
        # use source_selected from input_batch.
        batch_size = py_utils.GetShape(tf.nest.flatten(data_source)[0])[0]
        ret = py_utils.NestedMap()
        ret.data = data_source
        ret.bprop_variable_filters = bprop_variable_filters
        ret.selected_bprop = selected_bprop
        ret.source_selected = tf.tile(tf.expand_dims(selected_bprop, 0),
                                      [batch_size, 1])
        return ret
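# --- Toy sketch (not the lingvo MixByWeight implementation) of the mixing
# scheme described above: every batch is drawn entirely from one source,
# chosen with probability proportional to its weight. `sources` is a list of
# zero-argument callables, mirroring the lambdas built above.
import numpy as np

def MixByWeightSketch(sources, weights, rng):
  probs = np.asarray(weights, dtype=np.float64)
  probs /= probs.sum()
  idx = rng.choice(len(sources), p=probs)
  # Return the chosen batch plus a one-hot marking the selected source.
  return sources[idx](), np.eye(len(sources))[idx]

# e.g. MixByWeightSketch([lambda: 'a', lambda: 'b'], [0.2, 0.8],
#                        np.random.default_rng(0))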
Ejemplo n.º 30
 def _BroadcastAcrossPoints(z):
     return tf.transpose(tf.tile(z, [1, num_points]))
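# --- Small numeric check (hypothetical values) of the broadcast above: a
# [B, 1] column is tiled to [B, num_points] and transposed to [num_points, B].
import tensorflow as tf

z = tf.constant([[1.], [2.]])           # [B=2, 1]
out = tf.transpose(tf.tile(z, [1, 3]))  # [3, 2] with num_points=3
# [[1., 2.], [1., 2.], [1., 2.]]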