Beispiel #1
0
def proto_maml_fc_layer_init_fn(labels, embeddings, weights, biases,
                                prototype_multiplier):
    """Return a list of operations for reparameterized ProtoNet initialization.

    Class prototypes are the per-class mean of `embeddings` (grouped by
    `labels`), scaled by `prototype_multiplier`. The equivalent linear layer
    (w = 2p, b = -p^T.p) is zero-padded to the shapes of `weights` / `biases`
    and assigned to them.

    Args:
      labels: Integer class labels used as segment ids for `embeddings`.
      embeddings: Tensor whose leading dimension(s) match `labels`.
      weights: Variable of shape [embedding_size, max_ways]; its second
        dimension must be at least the number of ways.
      biases: Variable of shape [max_ways]; must have at least as many entries
        as the number of ways.
      prototype_multiplier: Scale applied to the prototypes.

    Returns:
      A list `[weights_assign_op, biases_assign_op]`.
    """

    # This is robust to classes missing from the training set, but assumes that
    # the last class is present. The max over the labels equals the max over
    # their unique values, so no `tf.unique` is needed.
    num_ways = tf.cast(
        tf.math.reduce_max(input_tensor=labels) + 1, tf.int32)

    # When there are no examples for a given class, we default its prototype to
    # zeros, per the implementation of `tf.math.unsorted_segment_mean`.
    prototypes = tf.math.unsorted_segment_mean(embeddings, labels, num_ways)

    # Scale the prototypes, which acts as a regularizer on the weights and biases.
    prototypes *= prototype_multiplier

    # logit = -<squared Euclidean distance to prototype>
    #       = -(x - p)^T.(x - p)
    #       = 2 x^T.p - p^T.p - x^T.x
    #       = x^T.w + b
    #         where w = 2p, b = -p^T.p
    output_weights = tf.transpose(a=2 * prototypes)
    output_biases = -tf.reduce_sum(input_tensor=prototypes * prototypes,
                                   axis=1)

    # We zero-pad to align with the original weights and biases.
    output_weights = tf.pad(tensor=output_weights,
                            paddings=[[0, 0],
                                      [
                                          0,
                                          tf.shape(input=weights)[1] -
                                          tf.shape(input=output_weights)[1]
                                      ]],
                            mode='CONSTANT',
                            constant_values=0)
    output_biases = tf.pad(tensor=output_biases,
                           paddings=[[
                               0,
                               tf.shape(input=biases)[0] -
                               tf.shape(input=output_biases)[0]
                           ]],
                           mode='CONSTANT',
                           constant_values=0)

    return [
        weights.assign(output_weights),
        biases.assign(output_biases),
    ]
Beispiel #2
0
def conv1d(x, filters, kernel_size, strides=1, padding='causal', dilation_rate=1, act=None,
           init=None, scope="conv1d", use_bias=True):
    """Applies a (possibly dilated, possibly causal) 1-D convolution to `x`.

    Args:
      x: Input tensor. Its static shape is unpacked into exactly three
        dimensions, and the last one is used as the kernel's input depth.
        Presumably (batch, time, channels) -- confirm with callers.
      filters: Number of output channels.
      kernel_size: Length of the convolution window.
      strides: Stride along the convolved axis.
      padding: 'causal' (left-pad so that no output depends on later steps) or
        a padding mode accepted by `tf.nn.convolution` ('valid' / 'same').
        String matching is case-insensitive.
      dilation_rate: Dilation along the convolved axis.
      act: Optional callable applied to the convolution output.
      init: Initializer for the 'kernel' variable.
      scope: Variable scope; variables are reused across calls
        (`tf.AUTO_REUSE`).
      use_bias: Whether to add a zero-initialized 'bias' variable.

    Returns:
      The convolution output, passed through `act` when it is given.
    """
    # The unpacking also enforces a rank-3 input; only the channel size is used.
    _, _, h = x.get_shape().as_list()
    # Taken from keras, there is a faster version from magenta
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        w = tf.get_variable('kernel', shape=(kernel_size, h, filters), dtype=tf.float32, initializer=init)

        if isinstance(padding, str) and padding.lower() == 'causal':
            # Causal (dilated) convolution: pad only on the left, so with VALID
            # padding each output sees only current and earlier positions.
            left_pad = dilation_rate * (kernel_size - 1)
            x = tf.pad(x, [[0, 0], [left_pad, 0], [0, 0]])
            padding = 'VALID'
        if isinstance(padding, str):
            # tf.nn.convolution expects upper-case "VALID" / "SAME"; normalize
            # so Keras-style lower-case names work too.
            padding = padding.upper()

        out = tf.nn.convolution(
            input=x,
            filter=w,
            dilation_rate=(dilation_rate,),
            strides=(strides,),
            padding=padding)
        if use_bias:
            b = tf.get_variable('bias', shape=(filters,), dtype=tf.float32, initializer=tf.initializers.zeros)
            out = tf.add(out, b)
        if act is not None:
            out = act(out)
    return out
Beispiel #3
0
    def __call__(self, example_string):
        """Turns one serialized Example into a preprocessed image tensor.

        Only the image is used; the label is ignored. The image is assumed to
        have three channels.

        Args:
          example_string: str, a serialized Example protocol buffer.

        Returns:
          The image resized to `self.image_size x self.image_size` and rescaled
          to [-1, 1]. Gaussian noise augmentation, when enabled, can push values
          outside that range.
        """
        size = self.image_size
        raw_image = read_example_and_parse_image(example_string)['image']
        resized = tf.cast(
            tf.image.resize_images(raw_image, [size, size],
                                   method=tf.image.ResizeMethod.BILINEAR,
                                   align_corners=True),
            tf.float32)
        image = 2 * (resized / 255.0 - 0.5)  # [0, 255] -> [-1, 1].

        augment = self.data_augmentation
        if augment is None:
            return image

        if augment.enable_gaussian_noise:
            image = image + tf.random_normal(
                tf.shape(image)) * augment.gaussian_noise_std

        if augment.enable_jitter:
            amount = augment.jitter_amount
            # Reflect-pad the spatial axes, then randomly crop back to size.
            image = tf.pad(image,
                           tf.constant([[amount, amount], [amount, amount],
                                        [0, 0]]),
                           'REFLECT')
            image = tf.image.random_crop(image, [size, size, 3])

        return image
Beispiel #4
0
 def fn_knn_graph_from_points_unbatched(i):
   """Computes the knn graph for batch element i.

   Only the first `num_valid_points[i]` points (and mask entries) are used. The
   resulting distances and indices are zero-padded back to `num_points` rows.
   """
   n_valid = num_valid_points[i]
   mask_i = None if mask is None else mask[i, :n_valid]
   distances_i, indices_i = knn_graph_from_points_unbatched(
       points=points[i, :n_valid, :],
       k=k,
       distance_upper_bound=distance_upper_bound,
       mask=mask_i)
   row_padding = [[0, num_points - n_valid], [0, 0]]
   return (tf.pad(distances_i, paddings=row_padding),
           tf.pad(indices_i, paddings=row_padding))
Beispiel #5
0
def per_voxel_point_sample_segment_func(data, segment_ids, num_segments,
                                        num_samples_per_voxel):
    """Samples features from the points within each voxel.

    For every segment (voxel), up to `num_samples_per_voxel` of its points are
    selected, highest point index first, and their rows of `data` are gathered.
    Sample slots that stay empty (voxels with fewer points) are filled with 0.

    Args:
      data: A tf.float32 tensor of size [N, F]. The static size of F must be
        known.
      segment_ids: A tf.int32 tensor of size [N] giving each point's segment.
      num_segments: Number of segments.
      num_samples_per_voxel: Number of features to sample per voxel. If the voxel
        has fewer points than this, the point features are padded with 0.

    Returns:
      A tf.float32 tensor of size [num_segments, num_samples_per_voxel, F].
      Only this tensor is returned. The int32 gather indices of size
      [num_segments, num_samples_per_voxel] are computed internally but are not
      returned.

    Raises:
      ValueError: If the static size of the last dimension of `data` is unknown.
    """
    num_channels = data.get_shape().as_list()[1]
    if num_channels is None:
        raise ValueError('num_channels is None.')
    # Number of points.
    n = tf.shape(segment_ids)[0]

    def _body_fn(i, indices_range, indices):
        """Computes the indices of the i-th point feature in each segment."""
        # `indices_range` holds 1-based point indices, and points sampled in
        # earlier steps are zeroed out. Take the largest remaining index per
        # segment. A fully consumed segment yields 0. An empty segment yields
        # the dtype's lowest value (see tf.math.unsorted_segment_max).
        indices_i = tf.math.unsorted_segment_max(data=indices_range,
                                                 segment_ids=segment_ids,
                                                 num_segments=num_segments)
        indices_i_positive_mask = tf.greater(indices_i, 0)
        indices_i_positive = tf.boolean_mask(indices_i,
                                             indices_i_positive_mask)
        # 1 at the 0-based position of every point sampled in this step.
        boolean_mask = tf.scatter_nd(indices=tf.cast(tf.expand_dims(
            indices_i_positive - 1, axis=1),
                                                     dtype=tf.int64),
                                     updates=tf.ones_like(indices_i_positive,
                                                          dtype=tf.int32),
                                     shape=(n, ))
        # Drop the sampled points so the next step picks different ones.
        indices_range *= (1 - boolean_mask)
        # Segments without a valid sample get index 0, i.e. the zero row below.
        indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
        # Place this step's indices into column i of the result.
        indices_i = tf.pad(tf.expand_dims(indices_i, axis=1),
                           paddings=[[0, 0],
                                     [i, num_samples_per_voxel - i - 1]])
        indices += indices_i
        i = i + 1
        return i, indices_range, indices

    cond = lambda i, indices_range, indices: i < num_samples_per_voxel

    (_, _, indices) = tf.while_loop(
        cond=cond,
        body=_body_fn,
        loop_vars=(tf.constant(0, dtype=tf.int32), tf.range(n) + 1,
                   tf.zeros([num_segments, num_samples_per_voxel],
                            dtype=tf.int32)))

    # Prepend a zero row: index 0 gathers zeros, and 1-based index k gathers
    # row k - 1 of the original data.
    data = tf.pad(data, paddings=[[1, 0], [0, 0]])
    voxel_features = tf.gather(data, tf.reshape(indices, [-1]))
    return tf.reshape(voxel_features,
                      [num_segments, num_samples_per_voxel, num_channels])
 def fn(i):
   """Gathers voxel features back onto the points of batch element i.

   Each of the first `num_valid_points[i]` points takes the feature row of its
   voxel (via `segment_ids`). The result is zero-padded along the first axis to
   `num_points` rows.
   """
   num_valid_voxels_i = num_valid_voxels[i]
   num_valid_points_i = num_valid_points[i]
   voxel_features_i = voxel_features[i, :num_valid_voxels_i, :]
   segment_ids_i = segment_ids[i, :num_valid_points_i]
   point_features = tf.gather(voxel_features_i, segment_ids_i)
   # `tf.Tensor` has no `shape_as_list()`; read the rank from the static shape.
   point_features_rank = len(point_features.get_shape().as_list())
   # Pad only the point axis; trailing axes are left unpadded.
   point_features_paddings = [[0, num_points - num_valid_points_i]]
   point_features_paddings.extend(
       [0, 0] for _ in range(point_features_rank - 1))
   return tf.pad(point_features, paddings=point_features_paddings)
Beispiel #7
0
def crop_and_pad_voxels(voxels, start_coordinates, end_coordinates):
    """Crops a region of a voxel grid, zero-padding where it leaves the grid.

    Start coordinates may be negative, and end coordinates may exceed the grid
    size. The part of the request outside the grid is filled with zeros (leading
    padding for negative starts, trailing padding for ends past the grid).

    Args:
      voxels: A tf.float32 tensor of shape [x, y, z, f] to crop.
      start_coordinates: A list of length 4 with the [x, y, z, f] start of the
        crop. Negative entries request leading padding.
      end_coordinates: A list of length 4 with the [x, y, z, f] end of the crop.
        Entries beyond the grid size request trailing padding.

    Returns:
      A voxel grid of shape `end_coordinates - start_coordinates` (elementwise).

    Raises:
      ValueError: If either coordinate list does not have length 4, or if any
        end coordinate is not positive.
    """
    if len(start_coordinates) != 4:
        raise ValueError('start_coordinates should be of length 4')
    if len(end_coordinates) != 4:
        raise ValueError('end_coordinates should be of length 4')
    if any(coord <= 0 for coord in end_coordinates):
        raise ValueError('Requested end coordinates should be > 0')

    start = tf.convert_to_tensor(start_coordinates, tf.int32)
    end = tf.convert_to_tensor(end_coordinates, tf.int32)
    grid_shape = voxels.shape

    # Portion of the request that actually lies inside the voxel grid.
    crop_begin = tf.maximum(0, start)
    crop_end = tf.minimum(grid_shape, end)
    cropped = tf.slice(voxels, begin=crop_begin, size=crop_end - crop_begin)

    # Zeros needed before / after the grid on every axis.
    pad_before = tf.maximum(0, -start)
    pad_after = tf.maximum(0, end - grid_shape)
    return tf.pad(cropped, tf.stack([pad_before, pad_after], axis=1))
Beispiel #8
0
  def proto_maml_fc_bias(self, prototypes, zero_pad_to_max_way=False):
    """Computes the Prototypical MAML fc layer's bias, b = -||p||^2 per class.

    Args:
      prototypes: Tensor of shape [num_classes, embedding_size]
      zero_pad_to_max_way: Whether to zero pad the bias up to
        `self.logit_dim` entries.

    Returns:
      fc_bias: Tensor of shape [num_classes] or [self.logit_dim]
        when zero_pad_to_max_way is True.
    """
    # Sum of squares equals square(norm) in value. It avoids tf.norm's NaN
    # gradient at an all-zero prototype (the derivative of sqrt at 0).
    fc_bias = -tf.reduce_sum(tf.square(prototypes), axis=1)
    if zero_pad_to_max_way:
      paddings = [[0, self.logit_dim - tf.shape(fc_bias)[0]]]
      fc_bias = tf.pad(fc_bias, paddings, 'CONSTANT', constant_values=0)
    return fc_bias
Beispiel #9
0
 def fn_normals_single_batch(i):
     """Estimates normals for batch element i.

     Only the first `num_valid_points[i]` points are used. The resulting normals
     are zero-padded back to `num_points` rows.
     """
     n_valid = num_valid_points[i]
     viewpoint_i = None if viewpoints is None else viewpoints[i, :]
     normals_i = points_to_normals_unbatched(
         points=points[i, 0:n_valid, :],
         k=k,
         distance_upper_bound=distance_upper_bound,
         viewpoint=viewpoint_i,
         noise_magnitude=noise_magnitude,
         method=method)
     row_padding = [[0, num_points - n_valid], [0, 0]]
     return tf.pad(normals_i, paddings=row_padding)
def compute_semantic_labels(inputs, points_key, box_margin=0.1):
  """Labels every point with the category of the object box containing it.

  Points inside an object box get that box's label. All other points get the
  background (unknown) label 0.

  Args:
    inputs: A dictionary containing points and objects.
    points_key: Key of the point-position tensor in `inputs`.
    box_margin: Amount by which object boxes are grown, so that points on a
      box boundary count as inside it.

  Returns:
    A tensor of size [num_points, 1] with the per-point semantic labels, with
    the dtype of `inputs['objects/category/label']`.

  Raises:
    ValueError: If the required object or point keys are not in inputs.
  """
  if points_key not in inputs:
    raise ValueError('points_key: %s not in inputs.' % points_key)
  if 'objects/shape/dimension' not in inputs:
    raise ValueError('`objects/shape/dimension` not in inputs.')
  if 'objects/pose/R' not in inputs:
    raise ValueError('`objects/pose/R` not in inputs.')
  if 'objects/pose/t' not in inputs:
    raise ValueError('`objects/pose/t` not in inputs.')
  if 'objects/category/label' not in inputs:
    raise ValueError('`objects/category/label` not in inputs.')

  dimensions = inputs['objects/shape/dimension']
  # Prepend a 0 (background) row so that box index -1 maps to label 0 after the
  # +1 shift below. This presumes map_points_to_boxes returns -1 for points
  # outside every box -- confirm in box_utils.
  padded_labels = tf.pad(
      tf.expand_dims(inputs['objects/category/label'], axis=1),
      paddings=[[1, 0], [0, 0]])
  box_index = box_utils.map_points_to_boxes(
      points=inputs[points_key],
      boxes_length=dimensions[:, 0:1],
      boxes_height=dimensions[:, 2:3],
      boxes_width=dimensions[:, 1:2],
      boxes_rotation_matrix=inputs['objects/pose/R'],
      boxes_center=inputs['objects/pose/t'],
      box_margin=box_margin)
  return tf.gather(padded_labels, box_index + 1)
Beispiel #11
0
  def proto_maml_fc_weights(self, prototypes, zero_pad_to_max_way=False):
    """Computes the Prototypical MAML fc layer's weights, w = 2p (transposed).

    Args:
      prototypes: Tensor of shape [num_classes, embedding_size]
      zero_pad_to_max_way: Whether to zero pad the class axis up to
        `self.logit_dim` columns.

    Returns:
      fc_weights: Tensor of shape [embedding_size, num_classes] or
        [embedding_size, self.logit_dim] when zero_pad_to_max_way is True.
    """
    fc_weights = tf.transpose(2 * prototypes)
    if not zero_pad_to_max_way:
      return fc_weights
    missing_ways = self.logit_dim - tf.shape(fc_weights)[1]
    return tf.pad(fc_weights, [[0, 0], [0, missing_ways]], 'CONSTANT',
                  constant_values=0)
def process_example(example_string, image_size, data_augmentation=None):
    """Decodes, resizes and rescales the image of one serialized Example.

    The label is parsed but ignored. The JPEG is decoded with three channels.

    Args:
      example_string: str, a serialized Example protocol buffer.
      image_size: int, side length the image is resized to
        (`[image_size, image_size]`).
      data_augmentation: Optional DataAugmentation object with parameters for
        perturbing the images.

    Returns:
      The image resized to `image_size x image_size` and rescaled to [-1, 1].
      Gaussian noise augmentation, when enabled, can push values outside that
      range.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], dtype=tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(example_string, features=feature_spec)
    decoded = tf.image.decode_jpeg(parsed['image'], channels=3)
    resized = tf.image.resize_images(
        decoded, [image_size, image_size],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True)
    image = 2 * (resized / 255.0 - 0.5)  # [0, 255] -> [-1, 1].

    if data_augmentation is None:
        return image

    if data_augmentation.enable_gaussian_noise:
        noise = tf.random_normal(
            tf.shape(image)) * data_augmentation.gaussian_noise_std
        image = image + noise

    if data_augmentation.enable_jitter:
        amount = data_augmentation.jitter_amount
        # Reflect-pad the spatial axes, then randomly crop back to size.
        reflect_pad = tf.constant([[amount, amount], [amount, amount], [0, 0]])
        image = tf.pad(image, reflect_pad, 'REFLECT')
        image = tf.image.random_crop(image, [image_size, image_size, 3])

    return image
 def _body_fn(i, indices_range, indices):
   """Computes the indices of the i-th point feature in each segment.

   Returns the updated loop variables `(i + 1, indices_range, indices)`, so it
   is presumably the body of a tf.while_loop (the call site is not shown here).
   `indices_range` holds 1-based point indices, with already-sampled points
   zeroed out. `indices` accumulates one column of sampled indices per call.
   Relies on `segment_ids`, `num_segments`, `n` and `num_samples_per_voxel`
   from the enclosing scope.
   """
   # Largest remaining 1-based point index per segment. Values <= 0 mean the
   # segment has no point left to sample.
   indices_i = tf.math.unsorted_segment_max(
       data=indices_range, segment_ids=segment_ids, num_segments=num_segments)
   indices_i_positive_mask = tf.greater(indices_i, 0)
   indices_i_positive = tf.boolean_mask(indices_i, indices_i_positive_mask)
   # 1 at the 0-based position of every point sampled in this call.
   boolean_mask = tf.scatter_nd(
       indices=tf.cast(
           tf.expand_dims(indices_i_positive - 1, axis=1), dtype=tf.int64),
       updates=tf.ones_like(indices_i_positive, dtype=tf.int32),
       shape=(n,))
   # Zero out sampled points so later calls pick different ones.
   indices_range *= (1 - boolean_mask)
   # Segments without a valid sample get index 0.
   indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
   # Place this call's indices into column i of `indices`.
   indices_i = tf.pad(
       tf.expand_dims(indices_i, axis=1),
       paddings=[[0, 0], [i, num_samples_per_voxel - i - 1]])
   indices += indices_i
   i = i + 1
   return i, indices_range, indices
Beispiel #14
0
def _transfer_object_properties_to_points(inputs):
    """Sets the object properties for the points that fall inside objects.

    For every per-object field present in `inputs`, a per-point counterpart is
    written back into `inputs`, which is modified in place. Each value is
    gathered with `object_instance_id_points`. A zero row is prepended to every
    per-object tensor first, so instance id 0 gathers zeros and id k gathers
    object k - 1.

    Args:
      inputs: A dictionary containing input tensors, modified in place. When
        any per-object field is present it must also contain
        `standard_fields.InputDataFields.object_instance_id_points`.

    Raises:
      ValueError: If a per-object tensor's static rank is not 1, 2 or 3.
    """
    fields = standard_fields.InputDataFields
    # (per-object source key, per-point destination key), in processing order.
    object_to_point_keys = (
        (fields.objects_class, fields.object_class_points),
        (fields.objects_center, fields.object_center_points),
        (fields.objects_length, fields.object_length_points),
        (fields.objects_height, fields.object_height_points),
        (fields.objects_width, fields.object_width_points),
        (fields.objects_rotation_matrix, fields.object_rotation_matrix_points),
    )
    dic = {
        point_key: inputs[object_key]
        for object_key, point_key in object_to_point_keys
        if object_key in inputs
    }

    # The id mapping is the same for every field. Build it once, and lazily, so
    # the instance-id key is only required when there is something to transfer.
    id_mapping = None
    for key, value in dic.items():
        rank = len(value.get_shape().as_list())
        if rank not in (1, 2, 3):
            raise ValueError(('Invalid shape for %s' % key))
        # Prepend one zero row along the object axis only.
        paddings = [[1, 0]] + [[0, 0] for _ in range(rank - 1)]
        temp_tensor = tf.pad(value, paddings=paddings)
        if id_mapping is None:
            id_mapping = tf.reshape(
                inputs[fields.object_instance_id_points], [-1])
        inputs[key] = tf.gather(temp_tensor, id_mapping)
 def fn(i):
   """Voxelizes the valid points of batch element i.

   Returns the voxel features, voxel indices and number of valid voxels (after
   `_pad_or_clip_voxels`). It also returns the per-point segment ids,
   zero-padded to `num_points`, and the voxel-grid start location.
   """
   n_valid_points = num_valid_points[i]
   voxelized = pointcloud_to_sparse_voxel_grid_unbatched(
       points=points[i, :n_valid_points, :],
       features=features[i, :n_valid_points, :],
       grid_cell_size=grid_cell_size,
       segment_func=segment_func)
   voxel_features_i, voxel_indices_i, segment_ids_i, start_location_i = (
       voxelized)
   padded_or_clipped = _pad_or_clip_voxels(
       voxel_features=voxel_features_i,
       voxel_indices=voxel_indices_i,
       num_valid_voxels=tf.shape(voxel_features_i)[0],
       segment_ids=segment_ids_i,
       voxels_pad_or_clip_size=voxels_pad_or_clip_size)
   (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
    segment_ids_i) = padded_or_clipped
   padded_segment_ids = tf.pad(
       segment_ids_i, paddings=[[0, num_points - n_valid_points]])
   return (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
           padded_segment_ids, start_location_i)
def preprocess(inputs,
               output_keys=None,
               is_training=False,
               using_sequence_dataset=False,
               num_frame_to_load=1,
               transform_points_fn=None,
               image_preprocess_fn_dic=None,
               images_points_correspondence_fn=None,
               compute_semantic_labels_fn=None,
               compute_motion_labels_fn=None,
               view_names=(),
               points_key='points',
               colors_key='colors',
               normals_key='normals',
               intensities_key='intensities',
               elongations_key='elongations',
               semantic_labels_key='semantic_labels',
               motion_labels_key='motion_labels',
               spin_coords_key=None,
               points_in_image_frame_key=None,
               num_points_to_randomly_sample=None,
               x_min_degree_rotation=None,
               x_max_degree_rotation=None,
               y_min_degree_rotation=None,
               y_max_degree_rotation=None,
               z_min_degree_rotation=None,
               z_max_degree_rotation=None,
               points_pad_or_clip_size=None,
               voxels_pad_or_clip_size=None,
               voxel_grid_cell_size=(0.1, 0.1, 0.1),
               num_offset_bins_x=4,
               num_offset_bins_y=4,
               num_offset_bins_z=4,
               point_feature_keys=('point_offsets', ),
               point_to_voxel_segment_func=tf.math.unsorted_segment_mean,
               x_random_crop_size=None,
               y_random_crop_size=None,
               min_scale_ratio=None,
               max_scale_ratio=None,
               semantic_labels_offset=0,
               ignore_labels=(),
               remove_unlabeled_images_and_points=False,
               labeled_view_name=None,
               only_keep_first_return_lidar_points=False):
    """Preprocesses a dictionary of `Tensor` inputs.

  If is_training=True, it will randomly rotate the points around the z axis,
  and will randomly flip the points with respect to x and/or y axis.

  Note that the preprocessor function does not correct normal vectors if they
  exist in the inputs.
  Note that the preprocessing effects all values of `inputs` that are `Tensors`.

  Args:
    inputs: A dictionary of inputs. Each value must be a `Tensor`.
    output_keys: Either None, or a list of strings containing the keys in the
      dictionary that is returned by the preprocess function.
    is_training: Whether we're training or testing.
    using_sequence_dataset: if true, the inputs will contain scene and multiple
      frames data.
    num_frame_to_load: If greater than 1, load multiframe point cloud point
      positions and its correspondence.
    transform_points_fn: Fn to transform other frames to a specific frame's
      coordinate.
    image_preprocess_fn_dic: Image preprocessing function. Maps view names to
      their image preprocessing functions. Set it to None, if there are no
      images to preprocess or you are not interested in preprocesing images.
    images_points_correspondence_fn: The function that computes correspondence
      between images and points.
    compute_semantic_labels_fn: If not None, semantic labels will be computed
      using this function.
    compute_motion_labels_fn: If not None, motion labels will be computed using
      this function.
    view_names: Names corresponding to 2d views of the scene.
    points_key: The key used for `points` in the inputs.
    colors_key: The key used for `colors` in the inputs.
    normals_key: The key used for 'normals' in the inputs.
    intensities_key: The key used for 'intensities' in the inputs.
    elongations_key: The key used for 'elongations' in the inputs.
    semantic_labels_key: The key used for 'semantic_labels' in the inputs.
    motion_labels_key: The key used for 'motion_labels' in the inputs.
    spin_coords_key: The key used for 'spin_coords' in the inputs. In Waymo
      data, spin_coords is a [num_points, 3] tensor that contains scan_index,
      shot_index, return_index. In Waymo data, return_index of the first return
      points is 0.
    points_in_image_frame_key: A string that identifies the tensor that contains
      the points_in_image_frame tensor. If None, it won't be used.
    num_points_to_randomly_sample: Number of points to randomly sample. If None,
      it will keep the original points and does not perform sampling.
    x_min_degree_rotation: Min degree of rotation around the x axis.
    x_max_degree_rotation: Max degree of ratation around the x axis.
    y_min_degree_rotation: Min degree of rotation around the y axis.
    y_max_degree_rotation: Max degree of ratation around the y axis.
    z_min_degree_rotation: Min degree of rotation around the z axis.
    z_max_degree_rotation: Max degree of ratation around the z axis.
    points_pad_or_clip_size: Number of target points to pad or clip to. If None,
      it will not perform the point padding.
    voxels_pad_or_clip_size: Number of target voxels to pad or clip to. If None,
      it will not perform the voxel padding.
    voxel_grid_cell_size: A three dimensional tuple determining the voxel grid
      size.
    num_offset_bins_x: Number of bins for point offsets in x direction.
    num_offset_bins_y: Number of bins for point offsets in y direction.
    num_offset_bins_z: Number of bins for point offsets in z direction.
    point_feature_keys: The keys used to form the voxel features.
    point_to_voxel_segment_func: The function used to aggregate the features
      of the points that fall in the same voxel.
    x_random_crop_size: Size of the random crop in x dimension. If None, random
      crop will not take place on x dimension.
    y_random_crop_size: Size of the random crop in y dimension. If None, random
      crop will not take place on y dimension.
    min_scale_ratio: Minimum scale ratio. Used for scaling point cloud.
    max_scale_ratio: Maximum scale ratio. Used for scaling point cloud.
    semantic_labels_offset: An integer offset that will be added to labels.
    ignore_labels: A tuple containing labels that should be ignored when
      computing the loss and metrics.
    remove_unlabeled_images_and_points: If True, removes the images that are not
      labeled and also removes the points that are associated with those images.
    labeled_view_name: The name of the view that is labeled, otherwise None.
    only_keep_first_return_lidar_points: If True, we only keep the first return
      lidar points.

  Returns:
    The mean subtracted points with an optional rotation applied.

  Raises:
    ValueError: if `inputs` doesn't contain the points_key.
    ValueError: if `points_in_image_frame` does not have rank 3.
  """
    inputs = dict(inputs)

    if using_sequence_dataset:
        all_frame_inputs = inputs
        scene = all_frame_inputs['scene']
        frame1 = all_frame_inputs['frame1']
        frame_start_index = all_frame_inputs['frame_start_index']
        inputs = dict(
            all_frame_inputs['frame0']
        )  # so that the following processing code can be unchanged.

    # Initializing empty dictionary for mesh, image, indices_2d and non tensor
    # inputs.
    non_tensor_inputs = {}
    view_image_inputs = {}
    view_indices_2d_inputs = {}
    mesh_inputs = {}

    if image_preprocess_fn_dic is None:
        image_preprocess_fn_dic = {}

    # Convert all float64 to float32 and all int64 to int32.
    for key in sorted(inputs):
        if isinstance(inputs[key], tf.Tensor):
            if inputs[key].dtype == tf.float64:
                inputs[key] = tf.cast(inputs[key], dtype=tf.float32)
            if inputs[key].dtype == tf.int64:
                inputs[key] = tf.cast(inputs[key], dtype=tf.int32)

    if points_key in inputs:
        inputs[standard_fields.InputDataFields.
               point_positions] = inputs[points_key]
    if colors_key is not None and colors_key in inputs:
        inputs[
            standard_fields.InputDataFields.point_colors] = inputs[colors_key]
    if normals_key is not None and normals_key in inputs:
        inputs[standard_fields.InputDataFields.
               point_normals] = inputs[normals_key]
    if intensities_key is not None and intensities_key in inputs:
        inputs[standard_fields.InputDataFields.
               point_intensities] = inputs[intensities_key]
    if elongations_key is not None and elongations_key in inputs:
        inputs[standard_fields.InputDataFields.
               point_elongations] = inputs[elongations_key]
    if semantic_labels_key is not None and semantic_labels_key in inputs:
        inputs[standard_fields.InputDataFields.
               object_class_points] = inputs[semantic_labels_key]
    if motion_labels_key is not None and motion_labels_key in inputs:
        inputs[standard_fields.InputDataFields.
               object_flow_points] = inputs[motion_labels_key]
    if spin_coords_key is not None and spin_coords_key in inputs:
        inputs[standard_fields.InputDataFields.
               point_spin_coordinates] = inputs[spin_coords_key]

    # Acquire point / image correspondences.
    if images_points_correspondence_fn is not None:
        fn_outputs = images_points_correspondence_fn(inputs)
        if 'points_position' in fn_outputs:
            inputs[standard_fields.InputDataFields.
                   point_positions] = fn_outputs['points_position']
        if 'points_intensity' in fn_outputs and intensities_key is not None:
            inputs[standard_fields.InputDataFields.
                   point_intensities] = fn_outputs['points_intensity']
        if 'points_elongation' in fn_outputs and elongations_key is not None:
            inputs[standard_fields.InputDataFields.
                   point_elongations] = fn_outputs['points_elongation']
        if 'points_label' in fn_outputs and semantic_labels_key is not None:
            inputs[standard_fields.InputDataFields.
                   object_class_points] = fn_outputs['points_label']
        if 'view_images' in fn_outputs:
            for key in sorted(fn_outputs['view_images']):
                if len(fn_outputs['view_images'][key].shape) != 4:
                    raise ValueError(('%s image should have rank 4.' % key))
            view_image_inputs = fn_outputs['view_images']
        if 'view_indices_2d' in fn_outputs:
            for key in sorted(fn_outputs['view_indices_2d']):
                if len(fn_outputs['view_indices_2d'][key].shape) != 3:
                    raise ValueError(
                        ('%s indices_2d should have rank 3.' % key))
            view_indices_2d_inputs = fn_outputs['view_indices_2d']
    else:
        if points_in_image_frame_key is not None:
            inputs['rgb_view/features'] = inputs['image']
            inputs['rgb_view/indices_2d'] = inputs[points_in_image_frame_key]
            if len(inputs['rgb_view/indices_2d'].shape) != 3:
                raise ValueError('`points_in_image_frame` should have rank 3.')

    frame0 = inputs.copy()
    if num_frame_to_load > 1:
        point_positions_list = [
            frame0[standard_fields.InputDataFields.point_positions]
        ]
        if view_indices_2d_inputs:
            view_indices_2d_list = [view_indices_2d_inputs[view_names[0]]]
        frame_source_list = [
            tf.zeros([
                tf.shape(
                    frame0[standard_fields.InputDataFields.point_positions])[0]
            ], tf.int32)
        ]
        for i in range(1, num_frame_to_load):
            target_frame_key = 'frame' + str(i)
            if images_points_correspondence_fn is not None:
                frame_i = images_points_correspondence_fn(
                    all_frame_inputs[target_frame_key])
            else:
                raise ValueError(
                    'images_points_correspondence_fn is needed for loading multi-frame pointclouds.'
                )
            transformed_point_positions = transform_points_fn(
                scene, frame_i['points_position'], frame_start_index,
                i + frame_start_index)
            point_positions_list.append(transformed_point_positions)
            if view_indices_2d_inputs:
                view_indices_2d_list.append(
                    frame_i['view_indices_2d'][view_names[0]])
            frame_source_list.append(
                tf.ones([tf.shape(transformed_point_positions)[0]], tf.int32) *
                i)

        # add multi-frame info to override inputs and view_indices_2d_inputs
        inputs[standard_fields.InputDataFields.
               point_frame_index] = tf.expand_dims(tf.concat(frame_source_list,
                                                             axis=0),
                                                   axis=1)
        inputs[standard_fields.InputDataFields.point_positions] = tf.concat(
            point_positions_list, axis=0)
        if view_indices_2d_inputs:
            view_indices_2d_inputs[view_names[0]] = tf.concat(
                view_indices_2d_list, axis=1)

    # Validate inputs.
    if standard_fields.InputDataFields.point_positions not in inputs:
        raise ValueError('`inputs` must contain a point_positions')
    if inputs[
            standard_fields.InputDataFields.point_positions].shape.ndims != 2:
        raise ValueError('points must be of rank 2.')
    if inputs[standard_fields.InputDataFields.point_positions].shape[1] != 3:
        raise ValueError('point should be 3 dimensional.')

    # Remove normal nans.
    if standard_fields.InputDataFields.point_normals in inputs:
        inputs[standard_fields.InputDataFields.point_normals] = tf.where(
            tf.math.is_nan(
                inputs[standard_fields.InputDataFields.point_normals]),
            tf.zeros_like(
                inputs[standard_fields.InputDataFields.point_normals]),
            inputs[standard_fields.InputDataFields.point_normals])

    # Compute semantic labels if compute_semantic_labels_fn is not None
    # An example is when the ground-truth contains 3d object boxes and not per
    # point labels. This would be a function that infers point labels from boxes.
    if compute_semantic_labels_fn is not None:
        inputs[standard_fields.InputDataFields.
               object_class_points] = compute_semantic_labels_fn(
                   inputs=frame0,
                   points_key=standard_fields.InputDataFields.point_positions)
    if compute_motion_labels_fn is not None:
        inputs[standard_fields.InputDataFields.
               object_flow_points] = compute_motion_labels_fn(
                   scene=scene,
                   frame0=frame0,
                   frame1=frame1,
                   frame_start_index=frame_start_index,
                   points_key=standard_fields.InputDataFields.point_positions)

    # Splitting inputs to {view_image_inputs,
    #                      view_indices_2d_inputs,
    #                      mesh_inputs,
    #                      non_tensor_inputs}
    mesh_keys = []
    for key in [
            standard_fields.InputDataFields.point_positions,
            standard_fields.InputDataFields.point_colors,
            standard_fields.InputDataFields.point_normals,
            standard_fields.InputDataFields.point_intensities,
            standard_fields.InputDataFields.point_elongations,
            standard_fields.InputDataFields.object_class_points,
            standard_fields.InputDataFields.point_spin_coordinates,
            standard_fields.InputDataFields.object_flow_points,
            standard_fields.InputDataFields.point_frame_index,
    ]:
        if key is not None and key in inputs:
            mesh_keys.append(key)
    view_image_names = [('%s/features' % key) for key in view_names]
    view_indices_2d_names = [('%s/indices_2d' % key) for key in view_names]

    # Additional key collecting
    for k, v in six.iteritems(inputs):
        if k in view_image_names:
            view_image_inputs[k] = v
        elif k in view_indices_2d_names:
            view_indices_2d_inputs[k] = v
        elif k in mesh_keys:
            if num_frame_to_load > 1:
                pad_size = tf.shape(
                    inputs[standard_fields.InputDataFields.
                           point_positions])[0] - tf.shape(v)[0]
                if k == standard_fields.InputDataFields.object_class_points:
                    pad_value = -1
                else:
                    pad_value = 0
                v = tf.pad(v, [[0, pad_size], [0, 0]],
                           constant_values=pad_value)
            mesh_inputs[k] = v
        else:
            non_tensor_inputs[k] = v

    # Remove points that are not in the lidar first return (optional)
    if only_keep_first_return_lidar_points:
        _remove_second_return_lidar_points(
            mesh_inputs=mesh_inputs,
            view_indices_2d_inputs=view_indices_2d_inputs)

    # Randomly sample points
    preprocessor_utils.randomly_sample_points(
        mesh_inputs=mesh_inputs,
        view_indices_2d_inputs=view_indices_2d_inputs,
        target_num_points=num_points_to_randomly_sample)

    # Add weights if it does not exist in inputs. The weight of the points with
    # label in `ignore_labels` is set to 0. This helps the loss and metrics to
    # ignore those labels.
    use_weights = (
        standard_fields.InputDataFields.object_class_points in mesh_inputs
        or standard_fields.InputDataFields.object_flow_points in mesh_inputs)
    if use_weights:
        if num_frame_to_load > 1:
            num_valid_points_frame0 = tf.shape(
                frame0[standard_fields.InputDataFields.point_positions])[0]
            num_additional_frame_points = tf.shape(
                mesh_inputs[standard_fields.InputDataFields.
                            object_class_points])[0] - num_valid_points_frame0
            weights = tf.concat([
                tf.ones([num_valid_points_frame0, 1], tf.float32),
                tf.zeros([num_additional_frame_points, 1], tf.float32)
            ],
                                axis=0)
        else:
            weights = tf.ones_like(mesh_inputs[
                standard_fields.InputDataFields.object_class_points],
                                   dtype=tf.float32)

    if standard_fields.InputDataFields.object_class_points in mesh_inputs:
        mesh_inputs[
            standard_fields.InputDataFields.object_class_points] = tf.cast(
                mesh_inputs[
                    standard_fields.InputDataFields.object_class_points],
                dtype=tf.int32)
        for ignore_label in ignore_labels:
            weights *= tf.cast(tf.not_equal(
                mesh_inputs[
                    standard_fields.InputDataFields.object_class_points],
                ignore_label),
                               dtype=tf.float32)
        mesh_inputs[
            standard_fields.InputDataFields.point_loss_weights] = weights
        mesh_inputs[standard_fields.InputDataFields.
                    object_class_points] += semantic_labels_offset

    # We normalize the intensities and elongations to be in a smaller range.
    if standard_fields.InputDataFields.point_intensities in mesh_inputs:
        mesh_inputs[standard_fields.InputDataFields.
                    point_intensities] = change_intensity_range(
                        intensities=mesh_inputs[
                            standard_fields.InputDataFields.point_intensities])
    if standard_fields.InputDataFields.point_elongations in mesh_inputs:
        mesh_inputs[
            standard_fields.InputDataFields.point_elongations] = (tf.cast(
                mesh_inputs[standard_fields.InputDataFields.point_elongations],
                dtype=tf.float32) * 2.0 / 255.0) - 1.0

    # Random scale the points.
    if min_scale_ratio is not None and max_scale_ratio is not None:
        scale_ratio = tf.random.uniform([],
                                        minval=min_scale_ratio,
                                        maxval=max_scale_ratio,
                                        dtype=tf.float32)
        mesh_inputs[
            standard_fields.InputDataFields.point_positions] *= scale_ratio
        if standard_fields.InputDataFields.object_flow_points in mesh_inputs:
            mesh_inputs[standard_fields.InputDataFields.
                        object_flow_points] *= scale_ratio

    # Random crop the points.
    randomly_crop_points(mesh_inputs=mesh_inputs,
                         view_indices_2d_inputs=view_indices_2d_inputs,
                         x_random_crop_size=x_random_crop_size,
                         y_random_crop_size=y_random_crop_size)

    # If training, pick the best labeled image and points that project to it.
    # In many datasets, only one image is labeled anyways.
    if remove_unlabeled_images_and_points:
        pick_labeled_image(mesh_inputs=mesh_inputs,
                           view_image_inputs=view_image_inputs,
                           view_indices_2d_inputs=view_indices_2d_inputs,
                           view_name=labeled_view_name)

    # Process images.
    preprocessor_utils.preprocess_images(
        view_image_inputs=view_image_inputs,
        view_indices_2d_inputs=view_indices_2d_inputs,
        image_preprocess_fn_dic=image_preprocess_fn_dic,
        is_training=is_training)

    # Record the original points.
    original_points = mesh_inputs[
        standard_fields.InputDataFields.point_positions]
    if standard_fields.InputDataFields.point_colors in mesh_inputs:
        original_colors = mesh_inputs[
            standard_fields.InputDataFields.point_colors]
    if standard_fields.InputDataFields.point_normals in mesh_inputs:
        original_normals = mesh_inputs[
            standard_fields.InputDataFields.point_normals]

    # Update feature visibility count.
    if 'feature_visibility_count' in mesh_inputs:
        mesh_inputs['feature_visibility_count'] = tf.maximum(
            mesh_inputs['feature_visibility_count'], 1)
        mesh_inputs['features'] /= tf.cast(
            mesh_inputs['feature_visibility_count'], dtype=tf.float32)

    # Subtract mean from points.
    mean_points = tf.reduce_mean(
        mesh_inputs[standard_fields.InputDataFields.point_positions], axis=0)
    mesh_inputs[
        standard_fields.InputDataFields.point_positions] -= tf.expand_dims(
            mean_points, axis=0)

    # Rotate points randomly.
    if standard_fields.InputDataFields.point_normals in mesh_inputs:
        normals = mesh_inputs[standard_fields.InputDataFields.point_normals]
    else:
        normals = None

    if standard_fields.InputDataFields.object_flow_points in mesh_inputs:
        motions = mesh_inputs[
            standard_fields.InputDataFields.object_flow_points]
    else:
        motions = None

    (mesh_inputs[standard_fields.InputDataFields.point_positions],
     rotated_normals, rotated_motions) = rotate_randomly(
         points=mesh_inputs[standard_fields.InputDataFields.point_positions],
         normals=normals,
         motions=motions,
         x_min_degree_rotation=x_min_degree_rotation,
         x_max_degree_rotation=x_max_degree_rotation,
         y_min_degree_rotation=y_min_degree_rotation,
         y_max_degree_rotation=y_max_degree_rotation,
         z_min_degree_rotation=z_min_degree_rotation,
         z_max_degree_rotation=z_max_degree_rotation)

    # Random flipping in x and y directions.
    (mesh_inputs[standard_fields.InputDataFields.point_positions],
     flipped_normals,
     flipped_motions) = flip_randomly_points_and_normals_motions(
         points=mesh_inputs[standard_fields.InputDataFields.point_positions],
         normals=rotated_normals,
         motions=rotated_motions,
         is_training=is_training)
    if standard_fields.InputDataFields.point_normals in mesh_inputs:
        mesh_inputs[
            standard_fields.InputDataFields.point_normals] = flipped_normals
    if standard_fields.InputDataFields.object_flow_points in mesh_inputs:
        mesh_inputs[standard_fields.InputDataFields.
                    object_flow_points] = flipped_motions
    # Normalize RGB to [-1.0, 1.0].
    if standard_fields.InputDataFields.point_colors in mesh_inputs:
        mesh_inputs[standard_fields.InputDataFields.point_colors] = tf.cast(
            mesh_inputs[standard_fields.InputDataFields.point_colors],
            dtype=tf.float32)
        mesh_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 /
                                                                      255.0)
        mesh_inputs[standard_fields.InputDataFields.point_colors] -= 1.0

    # Add original points to mesh inputs.
    mesh_inputs[standard_fields.InputDataFields.
                point_positions_original] = original_points
    if standard_fields.InputDataFields.point_colors in mesh_inputs:
        mesh_inputs[standard_fields.InputDataFields.
                    point_colors_original] = original_colors
    if standard_fields.InputDataFields.point_normals in mesh_inputs:
        mesh_inputs[standard_fields.InputDataFields.
                    point_normals_original] = original_normals

    # Pad or clip the point tensors.
    pad_or_clip(mesh_inputs=mesh_inputs,
                view_indices_2d_inputs=view_indices_2d_inputs,
                pad_or_clip_size=points_pad_or_clip_size)
    if num_frame_to_load > 1:
        # Note: num_valid_points is the sum of 'num_points_per_fram' for now.
        # num_points_per_frame is each frame's valid num of points.
        # TODO(huangrui): if random sampling is called earlier, the count here
        # is not guaranteed to be in order. need sorting.
        if num_points_to_randomly_sample is not None:
            raise ValueError(
                'randomly sample is not compatible with padding multi frame point clouds yet!'
            )
        _, _, mesh_inputs[standard_fields.InputDataFields.
                          num_valid_points_per_frame] = tf.unique_with_counts(
                              tf.reshape(
                                  mesh_inputs[standard_fields.InputDataFields.
                                              point_frame_index], [-1]))
        if points_pad_or_clip_size is not None:
            padded_points = tf.where_v2(
                tf.greater(
                    points_pad_or_clip_size, mesh_inputs[
                        standard_fields.InputDataFields.num_valid_points]),
                points_pad_or_clip_size -
                mesh_inputs[standard_fields.InputDataFields.num_valid_points],
                0)

            # Correct the potential unique count error from optionally padded 0s point
            # frame index.
            mesh_inputs[
                standard_fields.InputDataFields.
                num_valid_points_per_frame] -= tf.pad(
                    tf.expand_dims(padded_points, 0), [[
                        0,
                        tf.shape(mesh_inputs[standard_fields.InputDataFields.
                                             num_valid_points_per_frame])[0] -
                        1
                    ]])

    # Putting back the dictionaries together
    processed_inputs = mesh_inputs.copy()
    processed_inputs.update(non_tensor_inputs)
    for key in sorted(view_image_inputs):
        processed_inputs[('%s/features' % key)] = view_image_inputs[key]
    for key in sorted(view_indices_2d_inputs):
        processed_inputs[('%s/indices_2d' % key)] = view_indices_2d_inputs[key]

    # Create features that do not exist
    if 'point_offsets' in point_feature_keys:
        preprocessor_utils.add_point_offsets(
            inputs=processed_inputs, voxel_grid_cell_size=voxel_grid_cell_size)
    if 'point_offset_bins' in point_feature_keys:
        preprocessor_utils.add_point_offset_bins(
            inputs=processed_inputs,
            voxel_grid_cell_size=voxel_grid_cell_size,
            num_bins_x=num_offset_bins_x,
            num_bins_y=num_offset_bins_y,
            num_bins_z=num_offset_bins_z)

    # Voxelize point features
    preprocessor_utils.voxelize_point_features(
        inputs=processed_inputs,
        voxels_pad_or_clip_size=voxels_pad_or_clip_size,
        voxel_grid_cell_size=voxel_grid_cell_size,
        point_feature_keys=point_feature_keys,
        point_to_voxel_segment_func=point_to_voxel_segment_func,
        num_frame_to_load=num_frame_to_load)

    # Voxelize point / image correspondence indices
    preprocessor_utils.voxelize_point_to_view_correspondences(
        inputs=processed_inputs,
        view_indices_2d_inputs=view_indices_2d_inputs,
        voxels_pad_or_clip_size=voxels_pad_or_clip_size,
        voxel_grid_cell_size=voxel_grid_cell_size)

    # Voxelizing the semantic labels
    preprocessor_utils.voxelize_semantic_labels(
        inputs=processed_inputs,
        voxels_pad_or_clip_size=voxels_pad_or_clip_size,
        voxel_grid_cell_size=voxel_grid_cell_size)

    # Voxelizing the loss weights
    preprocessor_utils.voxelize_property_tensor(
        inputs=processed_inputs,
        point_tensor_key=standard_fields.InputDataFields.point_loss_weights,
        corresponding_voxel_tensor_key=standard_fields.InputDataFields.
        voxel_loss_weights,
        voxels_pad_or_clip_size=voxels_pad_or_clip_size,
        voxel_grid_cell_size=voxel_grid_cell_size,
        segment_func=tf.math.unsorted_segment_max)

    # Voxelizing the object flow
    if standard_fields.InputDataFields.object_flow_points in processed_inputs:
        preprocessor_utils.voxelize_property_tensor(
            inputs=processed_inputs,
            point_tensor_key=standard_fields.InputDataFields.
            object_flow_points,
            corresponding_voxel_tensor_key='object_flow_voxels_max',
            voxels_pad_or_clip_size=voxels_pad_or_clip_size,
            voxel_grid_cell_size=voxel_grid_cell_size,
            segment_func=tf.math.unsorted_segment_max)
        preprocessor_utils.voxelize_property_tensor(
            inputs=processed_inputs,
            point_tensor_key=standard_fields.InputDataFields.
            object_flow_points,
            corresponding_voxel_tensor_key='object_flow_voxels_min',
            voxels_pad_or_clip_size=voxels_pad_or_clip_size,
            voxel_grid_cell_size=voxel_grid_cell_size,
            segment_func=tf.math.unsorted_segment_min)
        processed_inputs[standard_fields.InputDataFields.
                         object_flow_voxels] = processed_inputs[
                             'object_flow_voxels_max'] + processed_inputs[
                                 'object_flow_voxels_min']

    if num_frame_to_load > 1:
        mesh_inputs[
            standard_fields.InputDataFields.num_valid_points] = mesh_inputs[
                standard_fields.InputDataFields.num_valid_points_per_frame][0]

    # Filter preprocessed_inputs by output_keys if it is not None.
    if output_keys is not None:
        processed_inputs = {
            k: v
            for k, v in six.iteritems(processed_inputs) if k in output_keys
        }
    return processed_inputs
Beispiel #17
0
    def __call__(self, example_string):
        """Parse, decode and preprocess a single serialized example.

        Only the image feature is kept: the label feature is parsed and then
        dropped. The image is always decoded with three channels.

        Args:
          example_string: str, a serialized Example protocol buffer.

        Returns:
          A float32 image tensor resized to `image_size x image_size` and
          rescaled to [-1, 1]. Gaussian noise augmentation, when enabled, can
          push values outside of that range.
        """
        features_spec = {
            'image': tf.FixedLenFeature([], dtype=tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }
        parsed_features = tf.parse_single_example(example_string,
                                                  features=features_spec)
        encoded_image = parsed_features['image']

        decoded = tf.image.decode_image(encoded_image, channels=3)
        # decode_image leaves the static shape unknown; pin the channel count.
        decoded.set_shape([None, None, 3])
        side = self.image_size
        resized = tf.image.resize_images(
            decoded, [side, side],
            method=tf.image.ResizeMethod.BILINEAR,
            align_corners=True)
        image = tf.cast(resized, tf.float32)

        augment = self.data_augmentation

        # Color / flip augmentations operate on the [0, 255] float image.
        if augment is not None:
            if augment.enable_random_brightness:
                image = tf.image.random_brightness(
                    image, augment.random_brightness_delta)
            if augment.enable_random_saturation:
                sat_delta = augment.random_saturation_delta
                image = tf.image.random_saturation(image, 1 - sat_delta,
                                                   1 + sat_delta)
            if augment.enable_random_contrast:
                con_delta = augment.random_contrast_delta
                image = tf.image.random_contrast(image, 1 - con_delta,
                                                 1 + con_delta)
            if augment.enable_random_hue:
                image = tf.image.random_hue(image, augment.random_hue_delta)
            if augment.enable_random_flip:
                image = tf.image.random_flip_left_right(image)

        # Map [0, 255] onto [-1, 1].
        image = 2 * (image / 255.0 - 0.5)

        # Noise / jitter augmentations operate on the rescaled image.
        if augment is not None:
            if augment.enable_gaussian_noise:
                noise = tf.random_normal(tf.shape(image))
                image = image + noise * augment.gaussian_noise_std
            if augment.enable_jitter:
                pad = augment.jitter_amount
                reflect_paddings = tf.constant([[pad, pad], [pad, pad],
                                                [0, 0]])
                # Reflect-pad by `pad` pixels, then crop back to full size at
                # a random position.
                image = tf.image.random_crop(
                    tf.pad(image, reflect_paddings, 'REFLECT'),
                    [side, side, 3])

        return image
def geometric_augmentation(images,
                           flow=None,
                           mask=None,
                           crop_height=640,
                           crop_width=640,
                           probability_flip_left_right=0.5,
                           probability_flip_up_down=0.1,
                           probability_scale=0.8,
                           probability_relative_scale=0.,
                           probability_stretch=0.8,
                           probability_rotation=0.0,
                           probability_relative_rotation=0.0,
                           probability_crop_offset=0.0,
                           min_bound_scale=-0.2,
                           max_bound_scale=0.6,
                           max_strech_scale=0.2,
                           min_bound_relative_scale=-0.1,
                           max_bound_relative_scale=0.1,
                           max_rotation_deg=15,
                           max_relative_rotation_deg=3,
                           max_relative_crop_offset=5,
                           return_full_scale=False):

  """Applies geometric augmentations to an image pair and corresponding flow.

  Augmentations are applied in this fixed order, each only when its
  probability is positive: left/right flip, up/down flip, scale/stretch,
  relative scale, rotation, relative rotation. A (possibly offset) crop of
  size crop_height x crop_width is always taken last.

  Args:
    images: Image pair of shape [2, height, width, channels].
    flow: Corresponding forward flow field of shape [height, width, 2].
    mask: Mask indicating which positions in the flow field hold valid flow
      vectors of shape [height, width, 1]. Non-valid positions are encoded with
      0, valid positions with 1.
    crop_height: Height of the final augmented output.
    crop_width: Width of the final augmented output.
    probability_flip_left_right: Probability of applying left/right flip.
    probability_flip_up_down: Probability of applying up/down flip.
    probability_scale: Probability of applying scale augmentation.
    probability_relative_scale: Probability of applying scale augmentation to
      only the second frame of the image pair.
    probability_stretch: Probability of applying stretch augmentation (scale
      without keeping the aspect ratio).
    probability_rotation: Probability of applying rotation augmentation.
    probability_relative_rotation: Probability of applying rotation augmentation
      to only the second frame of the image pair.
    probability_crop_offset: Probability of applying a relative offset while
      cropping.
    min_bound_scale: Defines the smallest possible scaling factor as
      2**min_bound_scale.
    max_bound_scale: Defines the largest possible scaling factor as
      2**max_bound_scale.
    max_strech_scale: Defines the smallest and largest possible stretching
      factor as 2**-max_strech_scale and 2**max_strech_scale.
    min_bound_relative_scale: Defines the smallest possible scaling factor for
      the relative scaling as 2**min_bound_relative_scale.
    max_bound_relative_scale: Defines the largest possible scaling factor for
      the relative scaling as 2**max_bound_relative_scale.
    max_rotation_deg: Defines the maximum angle of rotation in degrees.
    max_relative_rotation_deg: Defines the maximum angle of rotation in degrees
      for the relative rotation.
    max_relative_crop_offset: Defines the maximum relative offset in pixels for
      cropping.
    return_full_scale: bool. If True, the full-size (pre-crop) images are
      returned in addition to the geometrically augmented (cropped and / or
      resized) images, together with the crop offsets and the padding applied
      to the full-size images.

  Returns:
    If return_full_scale is False:
      Augmented images, flow and mask (if not None).
    If return_full_scale is True:
      Augmented images, flow, mask, full_size_images, offset_h, offset_w,
      pad_h and pad_w, where:
        full_size_images: the pre-crop image pair, trimmed to a window of at
          most crop_height + 200 rows and crop_width + 400 columns around the
          crop and then zero-padded at the top / left to exactly
          [2, crop_height + 200, crop_width + 400, 3].
        offset_h, offset_w: the crop offsets from `random_crop`, expressed
          relative to the trimmed full-size images *before* the top / left
          padding (add pad_h / pad_w to index into full_size_images).
        pad_h, pad_w: number of rows / columns padded at the top / left.
  """

  # Apply geometric augmentation.
  if probability_flip_left_right > 0:
    images, flow, mask = random_flip_left_right(
        images, flow, mask, probability_flip_left_right)

  if probability_flip_up_down > 0:
    images, flow, mask = random_flip_up_down(
        images, flow, mask, probability_flip_up_down)

  if probability_scale > 0 or probability_stretch > 0:
    images, flow, mask = random_scale(
        images,
        flow,
        mask,
        min_scale=min_bound_scale,
        max_scale=max_bound_scale,
        max_strech=max_strech_scale,
        probability_scale=probability_scale,
        probability_strech=probability_stretch)

  if probability_relative_scale > 0:
    images, flow, mask = random_scale_second(
        images, flow, mask,
        min_scale=min_bound_relative_scale,
        max_scale=max_bound_relative_scale,
        probability_scale=probability_relative_scale)

  if probability_rotation > 0:
    images, flow, mask = random_rotation(
        images, flow, mask,
        probability=probability_rotation,
        max_rotation=max_rotation_deg, not_empty_crop=True)

  if probability_relative_rotation > 0:
    images, flow, mask = random_rotation_second(
        images, flow, mask,
        probability=probability_relative_rotation,
        max_rotation=max_relative_rotation_deg, not_empty_crop=True)

  images_uncropped = images
  images, flow, mask, offset_h, offset_w = random_crop(
      images, flow, mask, crop_height, crop_width,
      relative_offset=max_relative_crop_offset,
      probability_crop_offset=probability_crop_offset)

  # Total context margin (pixels per axis) kept around the crop in the
  # full-scale images for full-scale warping; half of each margin is placed on
  # either side of the crop (i.e. 100 rows above/below, 200 columns
  # left/right).
  margin_h = 200
  margin_w = 400
  half_margin_h = margin_h // 2
  half_margin_w = margin_w // 2
  pad_to_size_h = crop_height + margin_h
  pad_to_size_w = crop_width + margin_w

  if return_full_scale:
    uncropped_shape = tf.shape(images_uncropped)
    # If the crop is larger than the pre-crop images (presumably because
    # random_crop had to pad a small input), use the crop itself as the
    # full-scale image, located at offset zero.
    if images.shape[1] > uncropped_shape[1] or images.shape[
        2] > uncropped_shape[2]:
      images_uncropped = images
      uncropped_shape = tf.shape(images_uncropped)
      offset_h = tf.zeros_like(offset_h)
      offset_w = tf.zeros_like(offset_w)

    # Trim rows to a crop_height + margin_h window around the crop, shifting
    # the window so it stays inside the image.
    if uncropped_shape[1] > pad_to_size_h:
      crop_ht = offset_h - half_margin_h
      crop_hb = offset_h + crop_height + half_margin_h
      # Shift down if the window starts above the top edge ...
      crop_hb += tf.maximum(0, -crop_ht)
      # ... and up if it ends below the bottom edge.
      crop_ht -= tf.maximum(0, -(uncropped_shape[1] - crop_hb))
      crop_ht = tf.maximum(crop_ht, 0)
      crop_hb = tf.minimum(crop_hb, uncropped_shape[1])
      # Re-express the crop offset relative to the trimmed window.
      offset_h -= crop_ht
      images_uncropped = images_uncropped[:, crop_ht:crop_hb, :, :]

    # Same trimming for columns, with a crop_width + margin_w window.
    if uncropped_shape[2] > pad_to_size_w:
      crop_wt = offset_w - half_margin_w
      crop_wb = offset_w + crop_width + half_margin_w
      crop_wb += tf.maximum(0, -crop_wt)
      crop_wt -= tf.maximum(0, -(uncropped_shape[2] - crop_wb))
      crop_wt = tf.maximum(crop_wt, 0)
      crop_wb = tf.minimum(crop_wb, uncropped_shape[2])
      offset_w -= crop_wt
      images_uncropped = images_uncropped[:, :, crop_wt:crop_wb, :]

    uncropped_shape = tf.shape(images_uncropped)
    # Zero-pad at the top / left up to the fixed full-scale size; the
    # assertions fail if the trimmed images are still larger than that size.
    pad_h = pad_to_size_h - uncropped_shape[1]
    pad_w = pad_to_size_w - uncropped_shape[2]
    with tf.control_dependencies([
        tf.compat.v1.assert_greater_equal(pad_h, 0),
        tf.compat.v1.assert_greater_equal(pad_w, 0)
    ]):
      images_uncropped = tf.pad(images_uncropped,
                                [[0, 0], [pad_h, 0], [pad_w, 0], [0, 0]])
    # NOTE: this static-shape check requires an image pair with 3 channels.
    images_uncropped = tf.ensure_shape(images_uncropped,
                                       [2, pad_to_size_h, pad_to_size_w, 3])
    return images, flow, mask, images_uncropped, offset_h, offset_w, pad_h, pad_w

  return images, flow, mask
Beispiel #19
0
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=None,
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=None,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):
    """Train an offline safety critic on data collected by a saved agent.

    Restores an agent checkpoint from ``<load_root_dir>/train`` and its replay
    buffer from ``<load_root_dir>/train/replay_buffer``. It then trains a new
    safety critic network (``sc_net_off``) on 2-step windows. Half of each batch
    is sampled from the full replay buffer and half from a buffer holding only
    failure and pre-failure transitions. The gin config, summaries and
    safety-critic checkpoints are written under
    ``<load_root_dir>/train/offline/<UTC timestamp>``.

    Args:
      load_root_dir: Root directory of the run to load from.
      env_load_fn: Callable ``env_load_fn(env_name, gym_env_wrappers=...)``
        returning a py environment. It is used for the training and eval
        environments.
      gym_env_wrappers: Sequence of gym wrappers. It is passed to
        ``env_load_fn`` and applied to the monitor environment. None (the
        default) means no wrappers.
      monitor: If True, one rollout of the eval policy is run every
        ``monitor_interval`` steps in an environment wrapped by
        ``misc.monitor_freq(1, <load_root_dir>/rollouts)``.
      env_name: Environment name passed to ``env_load_fn`` / ``gym.make``.
      agent_class: Agent class, or a key of ``ALGOS``.
      train_metrics_callback: Optional ``callback(results, step)`` invoked
        after each policy evaluation.
      actor_fc_layers: Actor layer sizes used to rebuild the loaded agent.
      critic_joint_fc_layers: Critic joint layer sizes used to rebuild the
        loaded agent. For agents in ``SAFETY_AGENTS``, the agent's own safety
        critic uses these sizes too.
      safety_critic_joint_fc_layers: Joint layer sizes of the offline safety
        critic that is trained here.
      safety_critic_lr: Adam learning rate for the offline safety critic.
      safety_critic_bias_init_val: If not None, constant initial value of the
        safety critic's value bias.
      safety_critic_kernel_scale: If not None, the kernel initializer is a
        fan-in truncated-normal variance-scaling initializer with this scale.
        Otherwise a uniform variance-scaling initializer with scale 1/3 is
        used.
      n_envs: Number of parallel eval environments. Defaults to
        ``num_eval_episodes``.
      target_safety: Safety threshold used by the eval policy, ``train_step``
        and the thresholded metrics. For agents in ``SAFETY_AGENTS``, a falsy
        value is replaced by the loaded agent's ``_target_safety``.
      fail_weight: If set, examples whose label is truthy are weighted
        ``fail_weight / 0.5`` and all others ``(1 - fail_weight) / 0.5``.
      num_global_steps: Number of offline training steps. The restored global
        step is reset to 0 before training.
      batch_size: ``batch_size // 2`` windows are drawn from each buffer per
        step, before boundary filtering.
      run_eval: If True, the eval policy is evaluated every ``eval_interval``
        steps.
      eval_metrics: Extra py metrics wrapped as TF metrics during eval.
      num_eval_episodes: Episodes per policy evaluation.
      eval_interval: Steps between safety-critic evals, and between policy
        evals when ``run_eval`` is set.
      train_checkpoint_interval: Steps between safety-critic checkpoints.
      summary_interval: Steps between recorded summaries.
      monitor_interval: Steps between monitor rollouts.
      summaries_flush_secs: Flush interval of both summary writers.
      debug_summaries: Forwarded to ``train_step``.
      seed: Optional random seed for TF and the eval environments. 0 is a
        valid seed.
    """
    # Fresh containers per call instead of shared mutable defaults.
    if gym_env_wrappers is None:
        gym_env_wrappers = []
    if eval_metrics is None:
        eval_metrics = []

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        # The first two entries (return and episode length) are the step
        # metrics that the remaining metrics are summarized against below.
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed is not None:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception as e:
                # Seeding is best-effort: failures are logged, not raised.
                logging.warning('Could not seed eval environments: %s', e)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        # NOTE(review): built with critic_joint_fc_layers, not
        # safety_critic_joint_fc_layers -- presumably to match the architecture
        # stored in the agent checkpoint; confirm.
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        # Fall back to the loaded agent's threshold when none (or 0) is given.
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    # Restart the counter so num_global_steps counts offline steps only.
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    # Thresholded metrics report one value per entry of `thresholds`.
    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed is not None:
        tf.compat.v1.set_random_seed(seed)

    # Use the caller's summaries_flush_secs; it was previously overwritten
    # here with a hard-coded 10.
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    # Uniform 2-step windows from the whole replay buffer.
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       num_steps=2,
                                       sample_batch_size=batch_size //
                                       2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    # NOTE(review): full_data is assumed to be laid out as [1, T, ...]
    # (replay buffer batch_size=1), so axis 1 below is time -- confirm.
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    # Shift masks by one step along axis 1: the step before each failure and
    # the step after each episode start (tf.roll wraps around at the ends).
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    # Keep failures and the steps right before them; the mask is padded to the
    # replay buffer's capacity before being handed to copy_rb.
    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    # 2-step windows drawn only from the failure buffer.
    sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3,
                                               sample_batch_size=batch_size //
                                               2,
                                               num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    # NOTE(review): mean_loss is never reset, so the logged value is a running
    # mean since the start of this run -- confirm this is intended.
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    # Summaries are only recorded on steps that are multiples of
    # summary_interval.
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            # Drop windows whose first step is an episode boundary.
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            # Label of each window: task_agn_rew at its second step.
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        result = metric.result()
                        if len(tf.squeeze(result).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=result,
                                                        step=global_step_val)
                        else:
                            # One scalar per threshold, suffixed with its value.
                            for i, threshold in enumerate(thresholds):
                                tf.compat.v2.summary.scalar(
                                    name='{}_{}'.format(metric.name, threshold),
                                    data=result[i],
                                    step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)