def proto_maml_fc_layer_init_fn(labels, embeddings, weights, biases,
                                prototype_multiplier):
  """Return a list of operations for reparameterized ProtoNet initialization."""

  # This is robust to classes missing from the training set, but assumes that
  # the last class is present.
  num_ways = tf.cast(
      tf.math.reduce_max(input_tensor=tf.unique(labels)[0]) + 1, tf.int32)

  # When there are no examples for a given class, we default its prototype to
  # zeros, per the implementation of `tf.math.unsorted_segment_mean`.
  prototypes = tf.math.unsorted_segment_mean(embeddings, labels, num_ways)

  # Scale the prototypes, which acts as a regularizer on the weights and
  # biases.
  prototypes *= prototype_multiplier

  # logit = -<squared Euclidean distance to prototype>
  #       = -(x - p)^T.(x - p)
  #       = 2 x^T.p - p^T.p - x^T.x
  #       = x^T.w + b
  # where w = 2p, b = -p^T.p
  output_weights = tf.transpose(a=2 * prototypes)
  output_biases = -tf.reduce_sum(input_tensor=prototypes * prototypes, axis=1)

  # We zero-pad to align with the original weights and biases.
  output_weights = tf.pad(
      tensor=output_weights,
      paddings=[[0, 0],
                [0,
                 tf.shape(input=weights)[1] - tf.shape(input=output_weights)[1]
                ]],
      mode='CONSTANT',
      constant_values=0)
  output_biases = tf.pad(
      tensor=output_biases,
      paddings=[[
          0,
          tf.shape(input=biases)[0] - tf.shape(input=output_biases)[0]
      ]],
      mode='CONSTANT',
      constant_values=0)
  return [
      weights.assign(output_weights),
      biases.assign(output_biases),
  ]

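# The derivation in the comment above can be sanity-checked numerically:
# dropping the class-independent -x^T.x term, the negative squared distance
# to a prototype is a linear function of x with w = 2p and b = -p^T.p.
# A minimal NumPy sketch (illustrative only, not part of the original code):
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)  # an embedding
p = rng.normal(size=5)  # a class prototype

neg_sq_dist = -np.sum((x - p)**2)          # -(x - p)^T.(x - p)
linear_form = x @ (2 * p) - p @ p - x @ x  # x^T.w + b - x^T.x
assert np.allclose(neg_sq_dist, linear_form)
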
def conv1d(x,
           filters,
           kernel_size,
           strides=1,
           padding='causal',
           dilation_rate=1,
           act=None,
           init=None,
           scope="conv1d",
           use_bias=True):
  batch_size, seq_len, h = x.get_shape().as_list()
  # Taken from keras, there is a faster version from magenta
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    # assert seq_len % dilation_rate == 0
    w = tf.get_variable(
        'kernel',
        shape=(kernel_size, h, filters),
        dtype=tf.float32,
        initializer=init)
    if padding == 'causal':
      # causal (dilated) convolution:
      left_pad = dilation_rate * (kernel_size - 1)
      pattern = [[0, 0], [left_pad, 0], [0, 0]]
      x = tf.pad(x, pattern)
      padding = 'VALID'
    out = tf.nn.convolution(
        input=x,
        filter=w,
        dilation_rate=(dilation_rate,),
        strides=(strides,),
        padding=padding)
    if use_bias:
      b = tf.get_variable(
          'bias',
          shape=(filters,),
          dtype=tf.float32,
          initializer=tf.initializers.zeros)
      out = tf.add(out, b)
    if act is not None:
      return act(out)
    return out

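# Hypothetical usage of `conv1d` above. The function relies on
# tf.variable_scope / tf.get_variable, so this sketch assumes a TF1-style
# graph context; names and shapes here are illustrative, not from the
# original code.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
x = tf.placeholder(tf.float32, shape=[8, 100, 16])  # [batch, time, channels]
# kernel_size=3 and dilation_rate=2 give left_pad = 2 * (3 - 1) = 4, so the
# output at time t never sees inputs later than t.
y = conv1d(x, filters=32, kernel_size=3, dilation_rate=2,
           padding='causal', act=tf.nn.relu, scope='tcn_block')
print(y.shape)  # (8, 100, 32): causal padding preserves the time dimension.
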
def __call__(self, example_string):
  """Processes a single example string.

  Extracts and processes the image, and ignores the label. We assume that the
  image has three channels.

  Args:
    example_string: str, an Example protocol buffer.

  Returns:
    image_rescaled: the image, resized to `image_size x image_size` and
      rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
      values to go beyond this range.
  """
  image_decoded = read_example_and_parse_image(example_string)['image']
  image_resized = tf.image.resize_images(
      image_decoded, [self.image_size, self.image_size],
      method=tf.image.ResizeMethod.BILINEAR,
      align_corners=True)
  image_resized = tf.cast(image_resized, tf.float32)
  image = 2 * (image_resized / 255.0 - 0.5)  # Rescale to [-1, 1].

  if self.data_augmentation is not None:
    if self.data_augmentation.enable_gaussian_noise:
      image = image + tf.random_normal(
          tf.shape(image)) * self.data_augmentation.gaussian_noise_std

    if self.data_augmentation.enable_jitter:
      j = self.data_augmentation.jitter_amount
      paddings = tf.constant([[j, j], [j, j], [0, 0]])
      image = tf.pad(image, paddings, 'REFLECT')
      image = tf.image.random_crop(image,
                                   [self.image_size, self.image_size, 3])

  return image

def fn_knn_graph_from_points_unbatched(i):
  """Computes knn graph for example i in the batch."""
  num_valid_points_i = num_valid_points[i]
  points_i = points[i, :num_valid_points_i, :]
  if mask is None:
    mask_i = None
  else:
    mask_i = mask[i, :num_valid_points_i]
  distances_i, indices_i = knn_graph_from_points_unbatched(
      points=points_i,
      k=k,
      distance_upper_bound=distance_upper_bound,
      mask=mask_i)
  distances_i = tf.pad(
      distances_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
  indices_i = tf.pad(
      indices_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
  return distances_i, indices_i

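# Hypothetical enclosing call: the per-example function above is presumably
# mapped over the batch so each example's knn result is padded back to
# `num_points` rows. Assumes the surrounding scope defines `points`,
# `num_valid_points`, `k`, etc.
batch_size = points.get_shape().as_list()[0]
distances, indices = tf.map_fn(
    fn=fn_knn_graph_from_points_unbatched,
    elems=tf.range(batch_size),
    dtype=(tf.float32, tf.int32))
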
def per_voxel_point_sample_segment_func(data, segment_ids, num_segments,
                                        num_samples_per_voxel):
  """Samples features from the points within each voxel.

  Args:
    data: A tf.float32 tensor of size [N, F].
    segment_ids: A tf.int32 tensor of size [N].
    num_segments: Number of segments.
    num_samples_per_voxel: Number of features to sample per voxel. If a voxel
      contains fewer points than this, its features are padded with 0.

  Returns:
    A tf.float32 tensor of size [num_segments, num_samples_per_voxel, F].
  """
  num_channels = data.get_shape().as_list()[1]
  if num_channels is None:
    raise ValueError('num_channels is None.')
  n = tf.shape(segment_ids)[0]

  def _body_fn(i, indices_range, indices):
    """Computes the indices of the i-th point feature in each segment."""
    indices_i = tf.math.unsorted_segment_max(
        data=indices_range, segment_ids=segment_ids, num_segments=num_segments)
    indices_i_positive_mask = tf.greater(indices_i, 0)
    indices_i_positive = tf.boolean_mask(indices_i, indices_i_positive_mask)
    boolean_mask = tf.scatter_nd(
        indices=tf.cast(
            tf.expand_dims(indices_i_positive - 1, axis=1), dtype=tf.int64),
        updates=tf.ones_like(indices_i_positive, dtype=tf.int32),
        shape=(n,))
    indices_range *= (1 - boolean_mask)
    indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
    indices_i = tf.pad(
        tf.expand_dims(indices_i, axis=1),
        paddings=[[0, 0], [i, num_samples_per_voxel - i - 1]])
    indices += indices_i
    i = i + 1
    return i, indices_range, indices

  cond = lambda i, indices_range, indices: i < num_samples_per_voxel

  (_, _, indices) = tf.while_loop(
      cond=cond,
      body=_body_fn,
      loop_vars=(tf.constant(0, dtype=tf.int32), tf.range(n) + 1,
                 tf.zeros([num_segments, num_samples_per_voxel],
                          dtype=tf.int32)))

  data = tf.pad(data, paddings=[[1, 0], [0, 0]])
  voxel_features = tf.gather(data, tf.reshape(indices, [-1]))
  return tf.reshape(voxel_features,
                    [num_segments, num_samples_per_voxel, num_channels])

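# The trailing tf.pad(data, paddings=[[1, 0], [0, 0]]) prepends a zero row so
# that the sentinel index 0 (an empty sample slot) gathers zero features; real
# point indices are shifted by +1 inside the loop. A toy eager-mode sketch
# with assumed shapes (illustrative only):
import tensorflow as tf

point_features = tf.constant(
    [[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]])
voxel_ids = tf.constant([0, 0, 1, 1, 1], dtype=tf.int32)  # voxel 2 is empty.
sampled = per_voxel_point_sample_segment_func(
    data=point_features,
    segment_ids=voxel_ids,
    num_segments=3,
    num_samples_per_voxel=2)
print(sampled.shape)  # (3, 2, 2); the empty voxel 2 comes back as all zeros.
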
def fn(i):
  num_valid_voxels_i = num_valid_voxels[i]
  num_valid_points_i = num_valid_points[i]
  voxel_features_i = voxel_features[i, :num_valid_voxels_i, :]
  segment_ids_i = segment_ids[i, :num_valid_points_i]
  point_features = tf.gather(voxel_features_i, segment_ids_i)
  point_features_rank = len(point_features.shape_as_list())
  point_features_paddings = [[0, num_points - num_valid_points_i]]
  for _ in range(point_features_rank - 1):
    point_features_paddings.append([0, 0])
  point_features = tf.pad(point_features, paddings=point_features_paddings)
  return point_features

def crop_and_pad_voxels(voxels, start_coordinates, end_coordinates):
  """Crops a voxel region and pads past the boundaries with zeros.

  This accepts start and end coordinates past the limits of the voxel grid,
  and uses it to calculate how much top/left/right/bottom padding to add.

  Args:
    voxels: A tf.float32 tensor of shape [x, y, z, f] to crop
    start_coordinates: A list of len 4 with the [x, y, z, f] starting location
      of our crop. This can be negative, which indicates left/top padding.
    end_coordinates: A list of len 4 with the [x, y, z, f] ending location of
      our crop. This can be beyond the size of the voxel tensor, which
      indicates padding.

  Returns:
    cropped_and_padded_voxels: A voxel grid with shape
      [end_coordinates[0] - start_coordinates[0],
       end_coordinates[1] - start_coordinates[1],
       end_coordinates[2] - start_coordinates[2],
       end_coordinates[3] - start_coordinates[3]]

  Raises:
    ValueError: If requested crop and pad is outside the bounds of what the
      function supports.
  """
  if len(start_coordinates) != 4:
    raise ValueError('start_coordinates should be of length 4')
  if len(end_coordinates) != 4:
    raise ValueError('end_coordinates should be of length 4')
  if any([coord <= 0 for coord in end_coordinates]):
    raise ValueError('Requested end coordinates should be > 0')

  start_coordinates = tf.convert_to_tensor(start_coordinates, tf.int32)
  end_coordinates = tf.convert_to_tensor(end_coordinates, tf.int32)

  # Clip the coordinates to within the voxel grid
  clipped_start_coordinates = tf.maximum(0, start_coordinates)
  clipped_end_coordinates = tf.minimum(voxels.shape, end_coordinates)

  cropped_voxels = tf.slice(
      voxels,
      begin=clipped_start_coordinates,
      size=(clipped_end_coordinates - clipped_start_coordinates))

  top_and_left_padding = tf.maximum(0, -start_coordinates)
  bottom_and_right_padding = tf.maximum(0, end_coordinates - voxels.shape)

  padding = tf.stack([top_and_left_padding, bottom_and_right_padding], axis=1)
  return tf.pad(cropped_voxels, padding)

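# Hypothetical usage: crop a [4, 4, 4, 1] grid with a window that starts one
# voxel before the grid on x and ends one voxel past it on y; both overruns
# come back as zero padding (eager-mode sketch, shapes are illustrative).
import tensorflow as tf

voxels = tf.ones([4, 4, 4, 1], dtype=tf.float32)
out = crop_and_pad_voxels(
    voxels, start_coordinates=[-1, 0, 0, 0], end_coordinates=[3, 5, 4, 1])
print(out.shape)  # (4, 5, 4, 1) == end_coordinates - start_coordinates.
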
def proto_maml_fc_bias(self, prototypes, zero_pad_to_max_way=False):
  """Computes the Prototypical MAML fc layer's bias.

  Args:
    prototypes: Tensor of shape [num_classes, embedding_size]
    zero_pad_to_max_way: Whether to zero-pad to the maximum number of ways.

  Returns:
    fc_bias: Tensor of shape [num_classes], or [self.logit_dim] when
      zero_pad_to_max_way is True.
  """
  fc_bias = -tf.square(tf.norm(prototypes, axis=1))
  if zero_pad_to_max_way:
    paddings = [[0, self.logit_dim - tf.shape(fc_bias)[0]]]
    fc_bias = tf.pad(fc_bias, paddings, 'CONSTANT', constant_values=0)
  return fc_bias

def fn_normals_single_batch(i):
  """Function for computing normals for a single batch."""
  num_valid_points_i = num_valid_points[i]
  points_i = points[i, 0:num_valid_points_i, :]
  if viewpoints is None:
    viewpoint_i = None
  else:
    viewpoint_i = viewpoints[i, :]
  normals_i = points_to_normals_unbatched(
      points=points_i,
      k=k,
      distance_upper_bound=distance_upper_bound,
      viewpoint=viewpoint_i,
      noise_magnitude=noise_magnitude,
      method=method)
  return tf.pad(
      normals_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])

def compute_semantic_labels(inputs, points_key, box_margin=0.1):
  """Computes ground-truth semantic labels of the points.

  If a point falls inside an object box, assigns it to the label of that box.
  Otherwise the point is assigned to background (unknown) which is label 0.

  Args:
    inputs: A dictionary containing points and objects.
    points_key: A string corresponding to the tensor of point positions in
      inputs.
    box_margin: A margin by which object boxes are grown. Useful to make sure
      points on the object box boundary fall inside the object.

  Returns:
    A tf.int32 tensor of size [num_points, 1] containing point semantic labels.

  Raises:
    ValueError: If the required object or point keys are not in inputs.
  """
  if points_key not in inputs:
    raise ValueError(('points_key: %s not in inputs.' % points_key))
  if 'objects/shape/dimension' not in inputs:
    raise ValueError('`objects/shape/dimension` not in inputs.')
  if 'objects/pose/R' not in inputs:
    raise ValueError('`objects/pose/R` not in inputs.')
  if 'objects/pose/t' not in inputs:
    raise ValueError('`objects/pose/t` not in inputs.')
  if 'objects/category/label' not in inputs:
    raise ValueError('`objects/category/label` not in inputs.')
  point_positions = inputs[points_key]
  boxes_length = inputs['objects/shape/dimension'][:, 0:1]
  boxes_width = inputs['objects/shape/dimension'][:, 1:2]
  boxes_height = inputs['objects/shape/dimension'][:, 2:3]
  boxes_rotation_matrix = inputs['objects/pose/R']
  boxes_center = inputs['objects/pose/t']
  boxes_label = tf.expand_dims(inputs['objects/category/label'], axis=1)
  boxes_label = tf.pad(boxes_label, paddings=[[1, 0], [0, 0]])
  points_box_index = box_utils.map_points_to_boxes(
      points=point_positions,
      boxes_length=boxes_length,
      boxes_height=boxes_height,
      boxes_width=boxes_width,
      boxes_rotation_matrix=boxes_rotation_matrix,
      boxes_center=boxes_center,
      box_margin=box_margin)
  return tf.gather(boxes_label, points_box_index + 1)

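# The pad-then-gather pair above is a sentinel trick: `map_points_to_boxes`
# returns -1 for points outside every box, so shifting indices by +1 makes
# those points pick up the prepended zero row, i.e. background label 0.
# A toy eager-mode illustration of just that lookup (values are made up):
import tensorflow as tf

boxes_label = tf.constant([[7], [3]], dtype=tf.int32)           # [num_boxes, 1]
points_box_index = tf.constant([-1, 0, 1, -1], dtype=tf.int32)  # -1 = no box
padded_label = tf.pad(boxes_label, paddings=[[1, 0], [0, 0]])   # row 0 = 0
point_labels = tf.gather(padded_label, points_box_index + 1)
print(point_labels.numpy().ravel())  # [0 7 3 0]
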
def proto_maml_fc_weights(self, prototypes, zero_pad_to_max_way=False):
  """Computes the Prototypical MAML fc layer's weights.

  Args:
    prototypes: Tensor of shape [num_classes, embedding_size]
    zero_pad_to_max_way: Whether to zero-pad to the maximum number of ways.

  Returns:
    fc_weights: Tensor of shape [embedding_size, num_classes], or
      [embedding_size, self.logit_dim] when zero_pad_to_max_way is True.
  """
  fc_weights = 2 * prototypes
  fc_weights = tf.transpose(fc_weights)
  if zero_pad_to_max_way:
    paddings = [[0, 0], [0, self.logit_dim - tf.shape(fc_weights)[1]]]
    fc_weights = tf.pad(fc_weights, paddings, 'CONSTANT', constant_values=0)
  return fc_weights

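# Hedged illustration of zero_pad_to_max_way: a 3-way episode is padded out to
# an assumed logit_dim of 5, so the fc variables keep one fixed shape across
# episodes with different numbers of classes (eager-mode sketch, values made
# up).
import tensorflow as tf

logit_dim = 5                              # assumed maximum number of ways
prototypes = tf.random.normal([3, 8])      # 3-way episode, 8-d embeddings
fc_weights = tf.transpose(2 * prototypes)  # [8, 3]
paddings = [[0, 0], [0, logit_dim - tf.shape(fc_weights)[1]]]
fc_weights = tf.pad(fc_weights, paddings, 'CONSTANT', constant_values=0)
print(fc_weights.shape)  # (8, 5); the last two columns are all zeros.
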
def process_example(example_string, image_size, data_augmentation=None):
  """Processes a single example string.

  Extracts and processes the image, and ignores the label. We assume that the
  image has three channels.

  Args:
    example_string: str, an Example protocol buffer.
    image_size: int, desired image size. The extracted image will be resized
      to `[image_size, image_size]`.
    data_augmentation: A DataAugmentation object with parameters for
      perturbing the images.

  Returns:
    image_rescaled: the image, resized to `image_size x image_size` and
      rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
      values to go beyond this range.
  """
  image_string = tf.parse_single_example(
      example_string,
      features={
          'image': tf.FixedLenFeature([], dtype=tf.string),
          'label': tf.FixedLenFeature([], tf.int64)
      })['image']
  image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  image_resized = tf.image.resize_images(
      image_decoded, [image_size, image_size],
      method=tf.image.ResizeMethod.BILINEAR,
      align_corners=True)
  image = 2 * (image_resized / 255.0 - 0.5)  # Rescale to [-1, 1].

  if data_augmentation is not None:
    if data_augmentation.enable_gaussian_noise:
      image = image + tf.random_normal(
          tf.shape(image)) * data_augmentation.gaussian_noise_std

    if data_augmentation.enable_jitter:
      j = data_augmentation.jitter_amount
      paddings = tf.constant([[j, j], [j, j], [0, 0]])
      image = tf.pad(image, paddings, 'REFLECT')
      image = tf.image.random_crop(image, [image_size, image_size, 3])

  return image

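# A hypothetical way to apply `process_example` over a TFRecord dataset of
# Example protos. The path and parameters are placeholders, and the function
# itself assumes TF1-style ops (tf.parse_single_example, resize_images).
import functools
import tensorflow.compat.v1 as tf

dataset = tf.data.TFRecordDataset('examples.tfrecords')  # placeholder path
dataset = dataset.map(
    functools.partial(process_example, image_size=84, data_augmentation=None),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(32)
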
def _body_fn(i, indices_range, indices):
  """Computes the indices of the i-th point feature in each segment."""
  indices_i = tf.math.unsorted_segment_max(
      data=indices_range, segment_ids=segment_ids, num_segments=num_segments)
  indices_i_positive_mask = tf.greater(indices_i, 0)
  indices_i_positive = tf.boolean_mask(indices_i, indices_i_positive_mask)
  boolean_mask = tf.scatter_nd(
      indices=tf.cast(
          tf.expand_dims(indices_i_positive - 1, axis=1), dtype=tf.int64),
      updates=tf.ones_like(indices_i_positive, dtype=tf.int32),
      shape=(n,))
  indices_range *= (1 - boolean_mask)
  indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
  indices_i = tf.pad(
      tf.expand_dims(indices_i, axis=1),
      paddings=[[0, 0], [i, num_samples_per_voxel - i - 1]])
  indices += indices_i
  i = i + 1
  return i, indices_range, indices

def _transfer_object_properties_to_points(inputs):
  """Sets the object properties for the points that fall inside objects.

  Args:
    inputs: A dictionary containing input tensors.
  """
  dic = {}
  if standard_fields.InputDataFields.objects_class in inputs:
    dic[standard_fields.InputDataFields.object_class_points] = inputs[
        standard_fields.InputDataFields.objects_class]
  if standard_fields.InputDataFields.objects_center in inputs:
    dic[standard_fields.InputDataFields.object_center_points] = inputs[
        standard_fields.InputDataFields.objects_center]
  if standard_fields.InputDataFields.objects_length in inputs:
    dic[standard_fields.InputDataFields.object_length_points] = inputs[
        standard_fields.InputDataFields.objects_length]
  if standard_fields.InputDataFields.objects_height in inputs:
    dic[standard_fields.InputDataFields.object_height_points] = inputs[
        standard_fields.InputDataFields.objects_height]
  if standard_fields.InputDataFields.objects_width in inputs:
    dic[standard_fields.InputDataFields.object_width_points] = inputs[
        standard_fields.InputDataFields.objects_width]
  if standard_fields.InputDataFields.objects_rotation_matrix in inputs:
    dic[standard_fields.InputDataFields.object_rotation_matrix_points] = inputs[
        standard_fields.InputDataFields.objects_rotation_matrix]
  for key, value in dic.items():
    if len(value.get_shape().as_list()) == 1:
      paddings = [[1, 0]]
    elif len(value.get_shape().as_list()) == 2:
      paddings = [[1, 0], [0, 0]]
    elif len(value.get_shape().as_list()) == 3:
      paddings = [[1, 0], [0, 0], [0, 0]]
    else:
      raise ValueError(('Invalid shape for %s' % key))
    temp_tensor = tf.pad(value, paddings=paddings)
    id_mapping = tf.reshape(
        inputs[standard_fields.InputDataFields.object_instance_id_points],
        [-1])
    inputs[key] = tf.gather(temp_tensor, id_mapping)

def fn(i):
  """Map function."""
  num_valid_points_i = num_valid_points[i]
  points_i = points[i, :num_valid_points_i, :]
  features_i = features[i, :num_valid_points_i, :]
  voxel_features_i, voxel_indices_i, segment_ids_i, voxel_start_location_i = (
      pointcloud_to_sparse_voxel_grid_unbatched(
          points=points_i,
          features=features_i,
          grid_cell_size=grid_cell_size,
          segment_func=segment_func))
  num_valid_voxels_i = tf.shape(voxel_features_i)[0]
  (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
   segment_ids_i) = _pad_or_clip_voxels(
       voxel_features=voxel_features_i,
       voxel_indices=voxel_indices_i,
       num_valid_voxels=num_valid_voxels_i,
       segment_ids=segment_ids_i,
       voxels_pad_or_clip_size=voxels_pad_or_clip_size)
  segment_ids_i = tf.pad(
      segment_ids_i, paddings=[[0, num_points - num_valid_points_i]])
  return (voxel_features_i, voxel_indices_i, num_valid_voxels_i, segment_ids_i,
          voxel_start_location_i)

def preprocess(inputs, output_keys=None, is_training=False, using_sequence_dataset=False, num_frame_to_load=1, transform_points_fn=None, image_preprocess_fn_dic=None, images_points_correspondence_fn=None, compute_semantic_labels_fn=None, compute_motion_labels_fn=None, view_names=(), points_key='points', colors_key='colors', normals_key='normals', intensities_key='intensities', elongations_key='elongations', semantic_labels_key='semantic_labels', motion_labels_key='motion_labels', spin_coords_key=None, points_in_image_frame_key=None, num_points_to_randomly_sample=None, x_min_degree_rotation=None, x_max_degree_rotation=None, y_min_degree_rotation=None, y_max_degree_rotation=None, z_min_degree_rotation=None, z_max_degree_rotation=None, points_pad_or_clip_size=None, voxels_pad_or_clip_size=None, voxel_grid_cell_size=(0.1, 0.1, 0.1), num_offset_bins_x=4, num_offset_bins_y=4, num_offset_bins_z=4, point_feature_keys=('point_offsets', ), point_to_voxel_segment_func=tf.math.unsorted_segment_mean, x_random_crop_size=None, y_random_crop_size=None, min_scale_ratio=None, max_scale_ratio=None, semantic_labels_offset=0, ignore_labels=(), remove_unlabeled_images_and_points=False, labeled_view_name=None, only_keep_first_return_lidar_points=False): """Preprocesses a dictionary of `Tensor` inputs. If is_training=True, it will randomly rotate the points around the z axis, and will randomly flip the points with respect to x and/or y axis. Note that the preprocessor function does not correct normal vectors if they exist in the inputs. Note that the preprocessing effects all values of `inputs` that are `Tensors`. Args: inputs: A dictionary of inputs. Each value must be a `Tensor`. output_keys: Either None, or a list of strings containing the keys in the dictionary that is returned by the preprocess function. is_training: Whether we're training or testing. using_sequence_dataset: if true, the inputs will contain scene and multiple frames data. num_frame_to_load: If greater than 1, load multiframe point cloud point positions and its correspondence. transform_points_fn: Fn to transform other frames to a specific frame's coordinate. image_preprocess_fn_dic: Image preprocessing function. Maps view names to their image preprocessing functions. Set it to None, if there are no images to preprocess or you are not interested in preprocesing images. images_points_correspondence_fn: The function that computes correspondence between images and points. compute_semantic_labels_fn: If not None, semantic labels will be computed using this function. compute_motion_labels_fn: If not None, motion labels will be computed using this function. view_names: Names corresponding to 2d views of the scene. points_key: The key used for `points` in the inputs. colors_key: The key used for `colors` in the inputs. normals_key: The key used for 'normals' in the inputs. intensities_key: The key used for 'intensities' in the inputs. elongations_key: The key used for 'elongations' in the inputs. semantic_labels_key: The key used for 'semantic_labels' in the inputs. motion_labels_key: The key used for 'motion_labels' in the inputs. spin_coords_key: The key used for 'spin_coords' in the inputs. In Waymo data, spin_coords is a [num_points, 3] tensor that contains scan_index, shot_index, return_index. In Waymo data, return_index of the first return points is 0. points_in_image_frame_key: A string that identifies the tensor that contains the points_in_image_frame tensor. If None, it won't be used. 
num_points_to_randomly_sample: Number of points to randomly sample. If None, it will keep the original points and does not perform sampling. x_min_degree_rotation: Min degree of rotation around the x axis. x_max_degree_rotation: Max degree of ratation around the x axis. y_min_degree_rotation: Min degree of rotation around the y axis. y_max_degree_rotation: Max degree of ratation around the y axis. z_min_degree_rotation: Min degree of rotation around the z axis. z_max_degree_rotation: Max degree of ratation around the z axis. points_pad_or_clip_size: Number of target points to pad or clip to. If None, it will not perform the point padding. voxels_pad_or_clip_size: Number of target voxels to pad or clip to. If None, it will not perform the voxel padding. voxel_grid_cell_size: A three dimensional tuple determining the voxel grid size. num_offset_bins_x: Number of bins for point offsets in x direction. num_offset_bins_y: Number of bins for point offsets in y direction. num_offset_bins_z: Number of bins for point offsets in z direction. point_feature_keys: The keys used to form the voxel features. point_to_voxel_segment_func: The function used to aggregate the features of the points that fall in the same voxel. x_random_crop_size: Size of the random crop in x dimension. If None, random crop will not take place on x dimension. y_random_crop_size: Size of the random crop in y dimension. If None, random crop will not take place on y dimension. min_scale_ratio: Minimum scale ratio. Used for scaling point cloud. max_scale_ratio: Maximum scale ratio. Used for scaling point cloud. semantic_labels_offset: An integer offset that will be added to labels. ignore_labels: A tuple containing labels that should be ignored when computing the loss and metrics. remove_unlabeled_images_and_points: If True, removes the images that are not labeled and also removes the points that are associated with those images. labeled_view_name: The name of the view that is labeled, otherwise None. only_keep_first_return_lidar_points: If True, we only keep the first return lidar points. Returns: The mean subtracted points with an optional rotation applied. Raises: ValueError: if `inputs` doesn't contain the points_key. ValueError: if `points_in_image_frame` does not have rank 3. """ inputs = dict(inputs) if using_sequence_dataset: all_frame_inputs = inputs scene = all_frame_inputs['scene'] frame1 = all_frame_inputs['frame1'] frame_start_index = all_frame_inputs['frame_start_index'] inputs = dict( all_frame_inputs['frame0'] ) # so that the following processing code can be unchanged. # Initializing empty dictionary for mesh, image, indices_2d and non tensor # inputs. non_tensor_inputs = {} view_image_inputs = {} view_indices_2d_inputs = {} mesh_inputs = {} if image_preprocess_fn_dic is None: image_preprocess_fn_dic = {} # Convert all float64 to float32 and all int64 to int32. for key in sorted(inputs): if isinstance(inputs[key], tf.Tensor): if inputs[key].dtype == tf.float64: inputs[key] = tf.cast(inputs[key], dtype=tf.float32) if inputs[key].dtype == tf.int64: inputs[key] = tf.cast(inputs[key], dtype=tf.int32) if points_key in inputs: inputs[standard_fields.InputDataFields. point_positions] = inputs[points_key] if colors_key is not None and colors_key in inputs: inputs[ standard_fields.InputDataFields.point_colors] = inputs[colors_key] if normals_key is not None and normals_key in inputs: inputs[standard_fields.InputDataFields. 
point_normals] = inputs[normals_key] if intensities_key is not None and intensities_key in inputs: inputs[standard_fields.InputDataFields. point_intensities] = inputs[intensities_key] if elongations_key is not None and elongations_key in inputs: inputs[standard_fields.InputDataFields. point_elongations] = inputs[elongations_key] if semantic_labels_key is not None and semantic_labels_key in inputs: inputs[standard_fields.InputDataFields. object_class_points] = inputs[semantic_labels_key] if motion_labels_key is not None and motion_labels_key in inputs: inputs[standard_fields.InputDataFields. object_flow_points] = inputs[motion_labels_key] if spin_coords_key is not None and spin_coords_key in inputs: inputs[standard_fields.InputDataFields. point_spin_coordinates] = inputs[spin_coords_key] # Acquire point / image correspondences. if images_points_correspondence_fn is not None: fn_outputs = images_points_correspondence_fn(inputs) if 'points_position' in fn_outputs: inputs[standard_fields.InputDataFields. point_positions] = fn_outputs['points_position'] if 'points_intensity' in fn_outputs and intensities_key is not None: inputs[standard_fields.InputDataFields. point_intensities] = fn_outputs['points_intensity'] if 'points_elongation' in fn_outputs and elongations_key is not None: inputs[standard_fields.InputDataFields. point_elongations] = fn_outputs['points_elongation'] if 'points_label' in fn_outputs and semantic_labels_key is not None: inputs[standard_fields.InputDataFields. object_class_points] = fn_outputs['points_label'] if 'view_images' in fn_outputs: for key in sorted(fn_outputs['view_images']): if len(fn_outputs['view_images'][key].shape) != 4: raise ValueError(('%s image should have rank 4.' % key)) view_image_inputs = fn_outputs['view_images'] if 'view_indices_2d' in fn_outputs: for key in sorted(fn_outputs['view_indices_2d']): if len(fn_outputs['view_indices_2d'][key].shape) != 3: raise ValueError( ('%s indices_2d should have rank 3.' % key)) view_indices_2d_inputs = fn_outputs['view_indices_2d'] else: if points_in_image_frame_key is not None: inputs['rgb_view/features'] = inputs['image'] inputs['rgb_view/indices_2d'] = inputs[points_in_image_frame_key] if len(inputs['rgb_view/indices_2d'].shape) != 3: raise ValueError('`points_in_image_frame` should have rank 3.') frame0 = inputs.copy() if num_frame_to_load > 1: point_positions_list = [ frame0[standard_fields.InputDataFields.point_positions] ] if view_indices_2d_inputs: view_indices_2d_list = [view_indices_2d_inputs[view_names[0]]] frame_source_list = [ tf.zeros([ tf.shape( frame0[standard_fields.InputDataFields.point_positions])[0] ], tf.int32) ] for i in range(1, num_frame_to_load): target_frame_key = 'frame' + str(i) if images_points_correspondence_fn is not None: frame_i = images_points_correspondence_fn( all_frame_inputs[target_frame_key]) else: raise ValueError( 'images_points_correspondence_fn is needed for loading multi-frame pointclouds.' ) transformed_point_positions = transform_points_fn( scene, frame_i['points_position'], frame_start_index, i + frame_start_index) point_positions_list.append(transformed_point_positions) if view_indices_2d_inputs: view_indices_2d_list.append( frame_i['view_indices_2d'][view_names[0]]) frame_source_list.append( tf.ones([tf.shape(transformed_point_positions)[0]], tf.int32) * i) # add multi-frame info to override inputs and view_indices_2d_inputs inputs[standard_fields.InputDataFields. 
point_frame_index] = tf.expand_dims(tf.concat(frame_source_list, axis=0), axis=1) inputs[standard_fields.InputDataFields.point_positions] = tf.concat( point_positions_list, axis=0) if view_indices_2d_inputs: view_indices_2d_inputs[view_names[0]] = tf.concat( view_indices_2d_list, axis=1) # Validate inputs. if standard_fields.InputDataFields.point_positions not in inputs: raise ValueError('`inputs` must contain a point_positions') if inputs[ standard_fields.InputDataFields.point_positions].shape.ndims != 2: raise ValueError('points must be of rank 2.') if inputs[standard_fields.InputDataFields.point_positions].shape[1] != 3: raise ValueError('point should be 3 dimensional.') # Remove normal nans. if standard_fields.InputDataFields.point_normals in inputs: inputs[standard_fields.InputDataFields.point_normals] = tf.where( tf.math.is_nan( inputs[standard_fields.InputDataFields.point_normals]), tf.zeros_like( inputs[standard_fields.InputDataFields.point_normals]), inputs[standard_fields.InputDataFields.point_normals]) # Compute semantic labels if compute_semantic_labels_fn is not None # An example is when the ground-truth contains 3d object boxes and not per # point labels. This would be a function that infers point labels from boxes. if compute_semantic_labels_fn is not None: inputs[standard_fields.InputDataFields. object_class_points] = compute_semantic_labels_fn( inputs=frame0, points_key=standard_fields.InputDataFields.point_positions) if compute_motion_labels_fn is not None: inputs[standard_fields.InputDataFields. object_flow_points] = compute_motion_labels_fn( scene=scene, frame0=frame0, frame1=frame1, frame_start_index=frame_start_index, points_key=standard_fields.InputDataFields.point_positions) # Splitting inputs to {view_image_inputs, # view_indices_2d_inputs, # mesh_inputs, # non_tensor_inputs} mesh_keys = [] for key in [ standard_fields.InputDataFields.point_positions, standard_fields.InputDataFields.point_colors, standard_fields.InputDataFields.point_normals, standard_fields.InputDataFields.point_intensities, standard_fields.InputDataFields.point_elongations, standard_fields.InputDataFields.object_class_points, standard_fields.InputDataFields.point_spin_coordinates, standard_fields.InputDataFields.object_flow_points, standard_fields.InputDataFields.point_frame_index, ]: if key is not None and key in inputs: mesh_keys.append(key) view_image_names = [('%s/features' % key) for key in view_names] view_indices_2d_names = [('%s/indices_2d' % key) for key in view_names] # Additional key collecting for k, v in six.iteritems(inputs): if k in view_image_names: view_image_inputs[k] = v elif k in view_indices_2d_names: view_indices_2d_inputs[k] = v elif k in mesh_keys: if num_frame_to_load > 1: pad_size = tf.shape( inputs[standard_fields.InputDataFields. point_positions])[0] - tf.shape(v)[0] if k == standard_fields.InputDataFields.object_class_points: pad_value = -1 else: pad_value = 0 v = tf.pad(v, [[0, pad_size], [0, 0]], constant_values=pad_value) mesh_inputs[k] = v else: non_tensor_inputs[k] = v # Remove points that are not in the lidar first return (optional) if only_keep_first_return_lidar_points: _remove_second_return_lidar_points( mesh_inputs=mesh_inputs, view_indices_2d_inputs=view_indices_2d_inputs) # Randomly sample points preprocessor_utils.randomly_sample_points( mesh_inputs=mesh_inputs, view_indices_2d_inputs=view_indices_2d_inputs, target_num_points=num_points_to_randomly_sample) # Add weights if it does not exist in inputs. 
The weight of the points with # label in `ignore_labels` is set to 0. This helps the loss and metrics to # ignore those labels. use_weights = ( standard_fields.InputDataFields.object_class_points in mesh_inputs or standard_fields.InputDataFields.object_flow_points in mesh_inputs) if use_weights: if num_frame_to_load > 1: num_valid_points_frame0 = tf.shape( frame0[standard_fields.InputDataFields.point_positions])[0] num_additional_frame_points = tf.shape( mesh_inputs[standard_fields.InputDataFields. object_class_points])[0] - num_valid_points_frame0 weights = tf.concat([ tf.ones([num_valid_points_frame0, 1], tf.float32), tf.zeros([num_additional_frame_points, 1], tf.float32) ], axis=0) else: weights = tf.ones_like(mesh_inputs[ standard_fields.InputDataFields.object_class_points], dtype=tf.float32) if standard_fields.InputDataFields.object_class_points in mesh_inputs: mesh_inputs[ standard_fields.InputDataFields.object_class_points] = tf.cast( mesh_inputs[ standard_fields.InputDataFields.object_class_points], dtype=tf.int32) for ignore_label in ignore_labels: weights *= tf.cast(tf.not_equal( mesh_inputs[ standard_fields.InputDataFields.object_class_points], ignore_label), dtype=tf.float32) mesh_inputs[ standard_fields.InputDataFields.point_loss_weights] = weights mesh_inputs[standard_fields.InputDataFields. object_class_points] += semantic_labels_offset # We normalize the intensities and elongations to be in a smaller range. if standard_fields.InputDataFields.point_intensities in mesh_inputs: mesh_inputs[standard_fields.InputDataFields. point_intensities] = change_intensity_range( intensities=mesh_inputs[ standard_fields.InputDataFields.point_intensities]) if standard_fields.InputDataFields.point_elongations in mesh_inputs: mesh_inputs[ standard_fields.InputDataFields.point_elongations] = (tf.cast( mesh_inputs[standard_fields.InputDataFields.point_elongations], dtype=tf.float32) * 2.0 / 255.0) - 1.0 # Random scale the points. if min_scale_ratio is not None and max_scale_ratio is not None: scale_ratio = tf.random.uniform([], minval=min_scale_ratio, maxval=max_scale_ratio, dtype=tf.float32) mesh_inputs[ standard_fields.InputDataFields.point_positions] *= scale_ratio if standard_fields.InputDataFields.object_flow_points in mesh_inputs: mesh_inputs[standard_fields.InputDataFields. object_flow_points] *= scale_ratio # Random crop the points. randomly_crop_points(mesh_inputs=mesh_inputs, view_indices_2d_inputs=view_indices_2d_inputs, x_random_crop_size=x_random_crop_size, y_random_crop_size=y_random_crop_size) # If training, pick the best labeled image and points that project to it. # In many datasets, only one image is labeled anyways. if remove_unlabeled_images_and_points: pick_labeled_image(mesh_inputs=mesh_inputs, view_image_inputs=view_image_inputs, view_indices_2d_inputs=view_indices_2d_inputs, view_name=labeled_view_name) # Process images. preprocessor_utils.preprocess_images( view_image_inputs=view_image_inputs, view_indices_2d_inputs=view_indices_2d_inputs, image_preprocess_fn_dic=image_preprocess_fn_dic, is_training=is_training) # Record the original points. original_points = mesh_inputs[ standard_fields.InputDataFields.point_positions] if standard_fields.InputDataFields.point_colors in mesh_inputs: original_colors = mesh_inputs[ standard_fields.InputDataFields.point_colors] if standard_fields.InputDataFields.point_normals in mesh_inputs: original_normals = mesh_inputs[ standard_fields.InputDataFields.point_normals] # Update feature visibility count. 
if 'feature_visibility_count' in mesh_inputs: mesh_inputs['feature_visibility_count'] = tf.maximum( mesh_inputs['feature_visibility_count'], 1) mesh_inputs['features'] /= tf.cast( mesh_inputs['feature_visibility_count'], dtype=tf.float32) # Subtract mean from points. mean_points = tf.reduce_mean( mesh_inputs[standard_fields.InputDataFields.point_positions], axis=0) mesh_inputs[ standard_fields.InputDataFields.point_positions] -= tf.expand_dims( mean_points, axis=0) # Rotate points randomly. if standard_fields.InputDataFields.point_normals in mesh_inputs: normals = mesh_inputs[standard_fields.InputDataFields.point_normals] else: normals = None if standard_fields.InputDataFields.object_flow_points in mesh_inputs: motions = mesh_inputs[ standard_fields.InputDataFields.object_flow_points] else: motions = None (mesh_inputs[standard_fields.InputDataFields.point_positions], rotated_normals, rotated_motions) = rotate_randomly( points=mesh_inputs[standard_fields.InputDataFields.point_positions], normals=normals, motions=motions, x_min_degree_rotation=x_min_degree_rotation, x_max_degree_rotation=x_max_degree_rotation, y_min_degree_rotation=y_min_degree_rotation, y_max_degree_rotation=y_max_degree_rotation, z_min_degree_rotation=z_min_degree_rotation, z_max_degree_rotation=z_max_degree_rotation) # Random flipping in x and y directions. (mesh_inputs[standard_fields.InputDataFields.point_positions], flipped_normals, flipped_motions) = flip_randomly_points_and_normals_motions( points=mesh_inputs[standard_fields.InputDataFields.point_positions], normals=rotated_normals, motions=rotated_motions, is_training=is_training) if standard_fields.InputDataFields.point_normals in mesh_inputs: mesh_inputs[ standard_fields.InputDataFields.point_normals] = flipped_normals if standard_fields.InputDataFields.object_flow_points in mesh_inputs: mesh_inputs[standard_fields.InputDataFields. object_flow_points] = flipped_motions # Normalize RGB to [-1.0, 1.0]. if standard_fields.InputDataFields.point_colors in mesh_inputs: mesh_inputs[standard_fields.InputDataFields.point_colors] = tf.cast( mesh_inputs[standard_fields.InputDataFields.point_colors], dtype=tf.float32) mesh_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 / 255.0) mesh_inputs[standard_fields.InputDataFields.point_colors] -= 1.0 # Add original points to mesh inputs. mesh_inputs[standard_fields.InputDataFields. point_positions_original] = original_points if standard_fields.InputDataFields.point_colors in mesh_inputs: mesh_inputs[standard_fields.InputDataFields. point_colors_original] = original_colors if standard_fields.InputDataFields.point_normals in mesh_inputs: mesh_inputs[standard_fields.InputDataFields. point_normals_original] = original_normals # Pad or clip the point tensors. pad_or_clip(mesh_inputs=mesh_inputs, view_indices_2d_inputs=view_indices_2d_inputs, pad_or_clip_size=points_pad_or_clip_size) if num_frame_to_load > 1: # Note: num_valid_points is the sum of 'num_points_per_fram' for now. # num_points_per_frame is each frame's valid num of points. # TODO(huangrui): if random sampling is called earlier, the count here # is not guaranteed to be in order. need sorting. if num_points_to_randomly_sample is not None: raise ValueError( 'randomly sample is not compatible with padding multi frame point clouds yet!' ) _, _, mesh_inputs[standard_fields.InputDataFields. num_valid_points_per_frame] = tf.unique_with_counts( tf.reshape( mesh_inputs[standard_fields.InputDataFields. 
point_frame_index], [-1])) if points_pad_or_clip_size is not None: padded_points = tf.where_v2( tf.greater( points_pad_or_clip_size, mesh_inputs[ standard_fields.InputDataFields.num_valid_points]), points_pad_or_clip_size - mesh_inputs[standard_fields.InputDataFields.num_valid_points], 0) # Correct the potential unique count error from optionally padded 0s point # frame index. mesh_inputs[ standard_fields.InputDataFields. num_valid_points_per_frame] -= tf.pad( tf.expand_dims(padded_points, 0), [[ 0, tf.shape(mesh_inputs[standard_fields.InputDataFields. num_valid_points_per_frame])[0] - 1 ]]) # Putting back the dictionaries together processed_inputs = mesh_inputs.copy() processed_inputs.update(non_tensor_inputs) for key in sorted(view_image_inputs): processed_inputs[('%s/features' % key)] = view_image_inputs[key] for key in sorted(view_indices_2d_inputs): processed_inputs[('%s/indices_2d' % key)] = view_indices_2d_inputs[key] # Create features that do not exist if 'point_offsets' in point_feature_keys: preprocessor_utils.add_point_offsets( inputs=processed_inputs, voxel_grid_cell_size=voxel_grid_cell_size) if 'point_offset_bins' in point_feature_keys: preprocessor_utils.add_point_offset_bins( inputs=processed_inputs, voxel_grid_cell_size=voxel_grid_cell_size, num_bins_x=num_offset_bins_x, num_bins_y=num_offset_bins_y, num_bins_z=num_offset_bins_z) # Voxelize point features preprocessor_utils.voxelize_point_features( inputs=processed_inputs, voxels_pad_or_clip_size=voxels_pad_or_clip_size, voxel_grid_cell_size=voxel_grid_cell_size, point_feature_keys=point_feature_keys, point_to_voxel_segment_func=point_to_voxel_segment_func, num_frame_to_load=num_frame_to_load) # Voxelize point / image correspondence indices preprocessor_utils.voxelize_point_to_view_correspondences( inputs=processed_inputs, view_indices_2d_inputs=view_indices_2d_inputs, voxels_pad_or_clip_size=voxels_pad_or_clip_size, voxel_grid_cell_size=voxel_grid_cell_size) # Voxelizing the semantic labels preprocessor_utils.voxelize_semantic_labels( inputs=processed_inputs, voxels_pad_or_clip_size=voxels_pad_or_clip_size, voxel_grid_cell_size=voxel_grid_cell_size) # Voxelizing the loss weights preprocessor_utils.voxelize_property_tensor( inputs=processed_inputs, point_tensor_key=standard_fields.InputDataFields.point_loss_weights, corresponding_voxel_tensor_key=standard_fields.InputDataFields. voxel_loss_weights, voxels_pad_or_clip_size=voxels_pad_or_clip_size, voxel_grid_cell_size=voxel_grid_cell_size, segment_func=tf.math.unsorted_segment_max) # Voxelizing the object flow if standard_fields.InputDataFields.object_flow_points in processed_inputs: preprocessor_utils.voxelize_property_tensor( inputs=processed_inputs, point_tensor_key=standard_fields.InputDataFields. object_flow_points, corresponding_voxel_tensor_key='object_flow_voxels_max', voxels_pad_or_clip_size=voxels_pad_or_clip_size, voxel_grid_cell_size=voxel_grid_cell_size, segment_func=tf.math.unsorted_segment_max) preprocessor_utils.voxelize_property_tensor( inputs=processed_inputs, point_tensor_key=standard_fields.InputDataFields. object_flow_points, corresponding_voxel_tensor_key='object_flow_voxels_min', voxels_pad_or_clip_size=voxels_pad_or_clip_size, voxel_grid_cell_size=voxel_grid_cell_size, segment_func=tf.math.unsorted_segment_min) processed_inputs[standard_fields.InputDataFields. 
object_flow_voxels] = processed_inputs[ 'object_flow_voxels_max'] + processed_inputs[ 'object_flow_voxels_min'] if num_frame_to_load > 1: mesh_inputs[ standard_fields.InputDataFields.num_valid_points] = mesh_inputs[ standard_fields.InputDataFields.num_valid_points_per_frame][0] # Filter preprocessed_inputs by output_keys if it is not None. if output_keys is not None: processed_inputs = { k: v for k, v in six.iteritems(processed_inputs) if k in output_keys } return processed_inputs
def __call__(self, example_string):
  """Processes a single example string.

  Extracts and processes the image, and ignores the label. We assume that the
  image has three channels.

  Args:
    example_string: str, an Example protocol buffer.

  Returns:
    image_rescaled: the image, resized to `image_size x image_size` and
      rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
      values to go beyond this range.
  """
  image_string = tf.parse_single_example(
      example_string,
      features={
          'image': tf.FixedLenFeature([], dtype=tf.string),
          'label': tf.FixedLenFeature([], tf.int64)
      })['image']
  image_decoded = tf.image.decode_image(image_string, channels=3)
  image_decoded.set_shape([None, None, 3])
  image_resized = tf.image.resize_images(
      image_decoded, [self.image_size, self.image_size],
      method=tf.image.ResizeMethod.BILINEAR,
      align_corners=True)
  image = tf.cast(image_resized, tf.float32)

  if self.data_augmentation is not None:
    if self.data_augmentation.enable_random_brightness:
      delta = self.data_augmentation.random_brightness_delta
      image = tf.image.random_brightness(image, delta)

    if self.data_augmentation.enable_random_saturation:
      delta = self.data_augmentation.random_saturation_delta
      image = tf.image.random_saturation(image, 1 - delta, 1 + delta)

    if self.data_augmentation.enable_random_contrast:
      delta = self.data_augmentation.random_contrast_delta
      image = tf.image.random_contrast(image, 1 - delta, 1 + delta)

    if self.data_augmentation.enable_random_hue:
      delta = self.data_augmentation.random_hue_delta
      image = tf.image.random_hue(image, delta)

    if self.data_augmentation.enable_random_flip:
      image = tf.image.random_flip_left_right(image)

  image = 2 * (image / 255.0 - 0.5)  # Rescale to [-1, 1].

  if self.data_augmentation is not None:
    if self.data_augmentation.enable_gaussian_noise:
      image = image + tf.random_normal(
          tf.shape(image)) * self.data_augmentation.gaussian_noise_std

    if self.data_augmentation.enable_jitter:
      j = self.data_augmentation.jitter_amount
      paddings = tf.constant([[j, j], [j, j], [0, 0]])
      image = tf.pad(image, paddings, 'REFLECT')
      image = tf.image.random_crop(image,
                                   [self.image_size, self.image_size, 3])

  return image

def geometric_augmentation(images, flow = None, mask = None, crop_height = 640, crop_width = 640, probability_flip_left_right = 0.5, probability_flip_up_down = 0.1, probability_scale = 0.8, probability_relative_scale = 0., probability_stretch = 0.8, probability_rotation = 0.0, probability_relative_rotation = 0.0, probability_crop_offset = 0.0, min_bound_scale = -0.2, max_bound_scale = 0.6, max_strech_scale = 0.2, min_bound_relative_scale = -0.1, max_bound_relative_scale = 0.1, max_rotation_deg = 15, max_relative_rotation_deg = 3, max_relative_crop_offset = 5, return_full_scale=False): """Applies geometric augmentations to an image pair and corresponding flow. Args: images: Image pair of shape [2, height, width, channels]. flow: Corresponding forward flow field of shape [height, width, 2]. mask: Mask indicating which positions in the flow field hold valid flow vectors of shape [height, width, 1]. Non-valid poisitions are encoded with 0, valid positions with 1. crop_height: Height of the final augmented output. crop_width: Width of the final augmented output. probability_flip_left_right: Probability of applying left/right flip. probability_flip_up_down: Probability of applying up/down flip probability_scale: Probability of applying scale augmentation. probability_relative_scale: Probability of applying scale augmentation to only the second frame of the the image pair. probability_stretch: Probability of applying stretch augmentation (scale without keeping the aspect ratio). probability_rotation: Probability of applying rotation augmentation. probability_relative_rotation: Probability of applying rotation augmentation to only the second frame of the the image pair. probability_crop_offset: Probability of applying a relative offset while cropping. min_bound_scale: Defines the smallest possible scaling factor as 2**min_bound_scale. max_bound_scale: Defines the largest possible scaling factor as 2**max_bound_scale. max_strech_scale: Defines the smallest and largest possible streching factor as 2**-max_strech_scale and 2**max_strech_scale. min_bound_relative_scale: Defines the smallest possible scaling factor for the relative scaling as 2**min_bound_relative_scale. max_bound_relative_scale: Defines the largest possible scaling factor for the relative scaling as 2**max_bound_relative_scale. max_rotation_deg: Defines the maximum angle of rotation in degrees. max_relative_rotation_deg: Defines the maximum angle of rotation in degrees for the relative rotation. max_relative_crop_offset: Defines the maximum relative offset in pixels for cropping. return_full_scale: bool. If this is passed, the full size images will be returned in addition to the geometrically augmented (cropped and / or resized) images. In addition to the resized images, the crop height, width, and any padding applied will be returned. Returns: if return_full_scale is False: Augmented images, flow and mask (if not None). if return_full_scale is True: Augmented images, flow, mask, full_size_images, crop_h, crop_w, pad_h, and pad_w. 
""" # apply geometric augmentation if probability_flip_left_right > 0: images, flow, mask = random_flip_left_right( images, flow, mask, probability_flip_left_right) if probability_flip_up_down > 0: images, flow, mask = random_flip_up_down( images, flow, mask, probability_flip_up_down) if probability_scale > 0 or probability_stretch > 0: images, flow, mask = random_scale( images, flow, mask, min_scale=min_bound_scale, max_scale=max_bound_scale, max_strech=max_strech_scale, probability_scale=probability_scale, probability_strech=probability_stretch) if probability_relative_scale > 0: images, flow, mask = random_scale_second( images, flow, mask, min_scale=min_bound_relative_scale, max_scale=max_bound_relative_scale, probability_scale=probability_relative_scale) if probability_rotation > 0: images, flow, mask = random_rotation( images, flow, mask, probability=probability_rotation, max_rotation=max_rotation_deg, not_empty_crop=True) if probability_relative_rotation > 0: images, flow, mask = random_rotation_second( images, flow, mask, probability=probability_relative_rotation, max_rotation=max_relative_rotation_deg, not_empty_crop=True) images_uncropped = images images, flow, mask, offset_h, offset_w = random_crop( images, flow, mask, crop_height, crop_width, relative_offset=max_relative_crop_offset, probability_crop_offset=probability_crop_offset) # Add 100 / 200 pixels to crop height / width for full scale warp pad_to_size_h = crop_height + 200 pad_to_size_w = crop_width + 400 if return_full_scale: if pad_to_size_w: uncropped_shape = tf.shape(images_uncropped) if images.shape[1] > uncropped_shape[1] or images.shape[ 2] > uncropped_shape[2]: images_uncropped = images uncropped_shape = tf.shape(images_uncropped) offset_h = tf.zeros_like(offset_h) offset_w = tf.zeros_like(offset_w) if uncropped_shape[1] > pad_to_size_h: crop_ht = offset_h - (200 // 2) crop_hb = offset_h + crop_height + (200 // 2) crop_hb += tf.maximum(0, -crop_ht) crop_ht -= tf.maximum(0, -(uncropped_shape[1] - crop_hb)) crop_ht = tf.maximum(crop_ht, 0) crop_hb = tf.minimum(crop_hb, uncropped_shape[1]) offset_h -= crop_ht images_uncropped = images_uncropped[:, crop_ht:crop_hb, :, :] if uncropped_shape[2] > pad_to_size_w: crop_wt = offset_w - (400 // 2) crop_wb = offset_w + crop_width + (400 // 2) crop_wb += tf.maximum(0, -crop_wt) crop_wt -= tf.maximum(0, -(uncropped_shape[2] - crop_wb)) crop_wt = tf.maximum(crop_wt, 0) crop_wb = tf.minimum(crop_wb, uncropped_shape[2]) offset_w -= crop_wt images_uncropped = images_uncropped[:, :, crop_wt:crop_wb, :] uncropped_shape = tf.shape(images_uncropped) # remove remove_pixels_w from the width while keeping the crop centered pad_h = pad_to_size_h - uncropped_shape[1] pad_w = pad_to_size_w - uncropped_shape[2] with tf.control_dependencies([ tf.compat.v1.assert_greater_equal(pad_h, 0), tf.compat.v1.assert_greater_equal(pad_w, 0) ]): images_uncropped = tf.pad(images_uncropped, [[0, 0], [pad_h, 0], [pad_w, 0], [0, 0]]) images_uncropped = tf.ensure_shape(images_uncropped, [2, pad_to_size_h, pad_to_size_w, 3]) return images, flow, mask, images_uncropped, offset_h, offset_w, pad_h, pad_w return images, flow, mask
def train_eval( load_root_dir, env_load_fn=None, gym_env_wrappers=[], monitor=False, env_name=None, agent_class=None, train_metrics_callback=None, # SacAgent args actor_fc_layers=(256, 256), critic_joint_fc_layers=(256, 256), # Safety Critic training args safety_critic_joint_fc_layers=None, safety_critic_lr=3e-4, safety_critic_bias_init_val=None, safety_critic_kernel_scale=None, n_envs=None, target_safety=0.2, fail_weight=None, # Params for train num_global_steps=10000, batch_size=256, # Params for eval run_eval=False, eval_metrics=[], num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging train_checkpoint_interval=10000, summary_interval=1000, monitor_interval=5000, summaries_flush_secs=10, debug_summaries=False, seed=None): if isinstance(agent_class, str): assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format( agent_class) agent_class = ALGOS.get(agent_class) train_ckpt_dir = osp.join(load_root_dir, 'train') rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer') py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers) tf_env = tf_py_environment.TFPyEnvironment(py_env) if monitor: vid_path = os.path.join(load_root_dir, 'rollouts') monitor_env_wrapper = misc.monitor_freq(1, vid_path) monitor_env = gym.make(env_name) for wrapper in gym_env_wrappers: monitor_env = wrapper(monitor_env) monitor_env = monitor_env_wrapper(monitor_env) # auto_reset must be False to ensure Monitor works correctly monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False) if run_eval: eval_dir = os.path.join(load_root_dir, 'eval') n_envs = n_envs or num_eval_episodes eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(prefix='EvalMetrics', buffer_size=num_eval_episodes, batch_size=n_envs), tf_metrics.AverageEpisodeLengthMetric( prefix='EvalMetrics', buffer_size=num_eval_episodes, batch_size=n_envs) ] + [ tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name)) for m in eval_metrics ] eval_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment([ lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers) ] * n_envs)) if seed: seeds = [seed * n_envs + i for i in range(n_envs)] try: eval_tf_env.pyenv.seed(seeds) except: pass global_step = tf.compat.v1.train.get_or_create_global_step() time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=agents.normal_projection_net) critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) if agent_class in SAFETY_AGENTS: safety_critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) tf_agent = agent_class(time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, safety_critic_network=safety_critic_net, train_step_counter=global_step, debug_summaries=False) else: tf_agent = agent_class(time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, train_step_counter=global_step, debug_summaries=False) collect_data_spec = tf_agent.collect_data_spec replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=1000000) replay_buffer = 
misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer) tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent) if agent_class in SAFETY_AGENTS: target_safety = target_safety or tf_agent._target_safety loaded_train_steps = global_step.numpy() logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir, loaded_train_steps) global_step.assign(0) tf.summary.experimental.set_step(global_step) thresholds = [target_safety, 0.5] sc_metrics = [ tf.keras.metrics.AUC(name='safety_critic_auc'), tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc', threshold=0.5), tf.keras.metrics.TruePositives(name='safety_critic_tp', thresholds=thresholds), tf.keras.metrics.FalsePositives(name='safety_critic_fp', thresholds=thresholds), tf.keras.metrics.TrueNegatives(name='safety_critic_tn', thresholds=thresholds), tf.keras.metrics.FalseNegatives(name='safety_critic_fn', thresholds=thresholds) ] if seed: tf.compat.v1.set_random_seed(seed) summaries_flush_secs = 10 timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S') offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp) config_saver = gin.tf.GinConfigSaverHook(offline_train_dir, summarize_config=True) tf.function(config_saver.after_create_session)() sc_summary_writer = tf.compat.v2.summary.create_file_writer( offline_train_dir, flush_millis=summaries_flush_secs * 1000) sc_summary_writer.set_as_default() if safety_critic_kernel_scale is not None: ki = tf.compat.v1.variance_scaling_initializer( scale=safety_critic_kernel_scale, mode='fan_in', distribution='truncated_normal') else: ki = tf.compat.v1.keras.initializers.VarianceScaling( scale=1. / 3., mode='fan_in', distribution='uniform') if safety_critic_bias_init_val is not None: bi = tf.constant_initializer(safety_critic_bias_init_val) else: bi = None sc_net_off = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=safety_critic_joint_fc_layers, kernel_initializer=ki, value_bias_initializer=bi, name='SafetyCriticOffline') sc_net_off.create_variables() target_sc_net_off = common.maybe_copy_target_network_with_checks( sc_net_off, None, 'TargetSafetyCriticNetwork') optimizer = tf.keras.optimizers.Adam(safety_critic_lr) sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic') sc_checkpointer = common.Checkpointer( ckpt_dir=sc_net_off_ckpt_dir, safety_critic=sc_net_off, target_safety_critic=target_sc_net_off, optimizer=optimizer, global_step=global_step, max_to_keep=5) sc_checkpointer.initialize_or_restore() resample_counter = py_metrics.CounterMetric('ActionResampleCounter') eval_policy = agents.SafeActorPolicyRSVar( time_step_spec=time_step_spec, action_spec=action_spec, actor_network=actor_net, safety_critic_network=sc_net_off, safety_threshold=target_safety, resample_counter=resample_counter, training=True) dataset = replay_buffer.as_dataset(num_parallel_calls=3, num_steps=2, sample_batch_size=batch_size // 2).prefetch(3) data = iter(dataset) full_data = replay_buffer.gather_all() fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool) fail_step = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, fail_mask), full_data) init_step = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data) before_fail_mask = tf.roll(fail_mask, [-1], axis=[1]) after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1]) before_fail_step = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data) after_init_step = nest_utils.fast_map_structure( lambda *x: 
tf.boolean_mask(*x, after_init_mask), full_data) filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask)) filter_mask = tf.pad( filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]]) n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy() failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=n_failures, dataset_window_shift=1) data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask) sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size // 2, num_steps=2).prefetch(3) neg_data = iter(sc_dataset_neg) get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0] eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step, after_init_step, get_action) losses = [] mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss') target_update = train_utils.get_target_updater(sc_net_off, target_sc_net_off) with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): while global_step.numpy() < num_global_steps: pos_experience, _ = next(data) neg_experience, _ = next(neg_data) exp = data_utils.concat_batches(pos_experience, neg_experience, collect_data_spec) boundary_mask = tf.logical_not(exp.is_boundary()[:, 0]) exp = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, boundary_mask), exp) safe_rew = exp.observation['task_agn_rew'][:, 1] if fail_weight: weights = tf.where(tf.cast(safe_rew, tf.bool), fail_weight / 0.5, (1 - fail_weight) / 0.5) else: weights = None train_loss, sc_loss, lam_loss = train_step( exp, safe_rew, tf_agent, sc_net=sc_net_off, target_sc_net=target_sc_net_off, metrics=sc_metrics, weights=weights, target_safety=target_safety, optimizer=optimizer, target_update=target_update, debug_summaries=debug_summaries) global_step.assign_add(1) global_step_val = global_step.numpy() losses.append( (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy())) mean_loss(train_loss) with tf.name_scope('Losses'): tf.compat.v2.summary.scalar(name='sc_loss', data=sc_loss, step=global_step_val) tf.compat.v2.summary.scalar(name='lam_loss', data=lam_loss, step=global_step_val) if global_step_val % summary_interval == 0: tf.compat.v2.summary.scalar(name=mean_loss.name, data=mean_loss.result(), step=global_step_val) if global_step_val % summary_interval == 0: with tf.name_scope('Metrics'): for metric in sc_metrics: if len(tf.squeeze(metric.result()).shape) == 0: tf.compat.v2.summary.scalar(name=metric.name, data=metric.result(), step=global_step_val) else: fmt_str = '_{}'.format(thresholds[0]) tf.compat.v2.summary.scalar( name=metric.name + fmt_str, data=metric.result()[0], step=global_step_val) fmt_str = '_{}'.format(thresholds[1]) tf.compat.v2.summary.scalar( name=metric.name + fmt_str, data=metric.result()[1], step=global_step_val) metric.reset_states() if global_step_val % eval_interval == 0: eval_sc(sc_net_off, step=global_step_val) if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if train_metrics_callback is not None: train_metrics_callback(results, global_step_val) metric_utils.log_metrics(eval_metrics) with eval_summary_writer.as_default(): for eval_metric in eval_metrics[2:]: eval_metric.tf_summaries( train_step=global_step, step_metrics=eval_metrics[:2]) if monitor and global_step_val % monitor_interval == 0: monitor_time_step = monitor_py_env.reset() 
monitor_policy_state = eval_policy.get_initial_state(1) ep_len = 0 monitor_start = time.time() while not monitor_time_step.is_last(): monitor_action = eval_policy.action( monitor_time_step, monitor_policy_state) action, monitor_policy_state = monitor_action.action, monitor_action.state monitor_time_step = monitor_py_env.step(action) ep_len += 1 logging.debug( 'saved rollout at timestep %d, rollout length: %d, %4.2f sec', global_step_val, ep_len, time.time() - monitor_start) if global_step_val % train_checkpoint_interval == 0: sc_checkpointer.save(global_step=global_step_val)