def fn():
  for field in standard_fields.get_input_object_fields():
    if field in inputs:
      inputs[field] = tf.boolean_mask(inputs[field], valid_mask)
  for field in standard_fields.get_output_object_fields():
    if field in outputs:
      outputs[field] = tf.boolean_mask(outputs[field], valid_mask)
  return _box_corner_distance_loss(
      loss_type=loss_type,
      is_balanced=is_balanced,
      input_boxes_length=inputs[
          standard_fields.InputDataFields.objects_length],
      input_boxes_height=inputs[
          standard_fields.InputDataFields.objects_height],
      input_boxes_width=inputs[standard_fields.InputDataFields.objects_width],
      input_boxes_center=inputs[
          standard_fields.InputDataFields.objects_center],
      input_boxes_rotation_matrix=inputs[
          standard_fields.InputDataFields.objects_rotation_matrix],
      input_boxes_instance_id=inputs[
          standard_fields.InputDataFields.objects_instance_id],
      output_boxes_length=outputs[
          standard_fields.DetectionResultFields.objects_length],
      output_boxes_height=outputs[
          standard_fields.DetectionResultFields.objects_height],
      output_boxes_width=outputs[
          standard_fields.DetectionResultFields.objects_width],
      output_boxes_center=outputs[
          standard_fields.DetectionResultFields.objects_center],
      output_boxes_rotation_matrix=outputs[
          standard_fields.DetectionResultFields.objects_rotation_matrix],
      delta=delta)
def update_state(self, inputs, outputs):
  """Function that updates the metric state at each example.

  Args:
    inputs: A dictionary containing input tensors.
    outputs: A dictionary containing output tensors.

  Returns:
    Update op.
  """
  detections_score = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_score], [-1])
  detections_class = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_class], [-1])
  num_detections = tf.shape(detections_score)[0]
  detections_instance_mask = tf.reshape(
      outputs[
          standard_fields.DetectionResultFields.instance_segments_voxel_mask],
      [num_detections, -1])
  gt_class = tf.reshape(inputs[standard_fields.InputDataFields.objects_class],
                        [-1])
  num_gt = tf.shape(gt_class)[0]
  gt_voxel_instance_ids = tf.reshape(
      inputs[standard_fields.InputDataFields.object_instance_id_voxels], [-1])
  gt_instance_masks = tf.transpose(
      tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
  for c in self.class_range:
    gt_mask_c = tf.equal(gt_class, c)
    num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
    gt_instance_masks_c = tf.boolean_mask(gt_instance_masks, gt_mask_c)
    detections_mask_c = tf.equal(detections_class, c)
    num_detections_c = tf.math.reduce_sum(
        tf.cast(detections_mask_c, dtype=tf.int32))
    if num_detections_c == 0:
      continue
    det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
    det_instance_mask_c = tf.boolean_mask(detections_instance_mask,
                                          detections_mask_c)
    det_scores_c, sorted_indices = tf.math.top_k(
        det_scores_c, k=num_detections_c)
    det_instance_mask_c = tf.gather(det_instance_mask_c, sorted_indices)
    tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
    if num_gt_c > 0:
      ious_c = instance_segmentation_utils.points_mask_iou(
          masks1=gt_instance_masks_c, masks2=det_instance_mask_c)
      max_overlap_gt_ids = tf.cast(
          tf.math.argmax(ious_c, axis=0), dtype=tf.int32)
      is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
      for i in tf.range(num_detections_c):
        gt_id = max_overlap_gt_ids[i]
        if (ious_c[gt_id, i] > self.iou_threshold and
            is_gt_box_detected[gt_id] == 0):
          tp_c = tf.maximum(
              tf.one_hot(i, num_detections_c, dtype=tf.int32), tp_c)
          is_gt_box_detected = tf.maximum(
              tf.one_hot(gt_id, num_gt_c, dtype=tf.int32), is_gt_box_detected)
    self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
    self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
    self.num_gt[c] += num_gt_c
  return tf.no_op()
def randomly_crop_points(mesh_inputs,
                         view_indices_2d_inputs,
                         x_random_crop_size,
                         y_random_crop_size,
                         epsilon=1e-5):
  """Randomly crops points.

  Args:
    mesh_inputs: A dictionary containing input mesh (point) tensors.
    view_indices_2d_inputs: A dictionary containing input point to view
      correspondence tensors.
    x_random_crop_size: Size of the random crop in x dimension. If None, random
      crop will not take place on x dimension.
    y_random_crop_size: Size of the random crop in y dimension. If None, random
      crop will not take place on y dimension.
    epsilon: Epsilon (a very small value) used to add as a small margin to
      thresholds.
  """
  if x_random_crop_size is None and y_random_crop_size is None:
    return

  points = mesh_inputs[standard_fields.InputDataFields.point_positions]
  num_points = tf.shape(points)[0]
  # Pick a random point.
  if x_random_crop_size is not None or y_random_crop_size is not None:
    random_index = tf.random.uniform([],
                                     minval=0,
                                     maxval=num_points,
                                     dtype=tf.int32)
    center_x = points[random_index, 0]
    center_y = points[random_index, 1]

  points_x = points[:, 0]
  points_y = points[:, 1]
  min_x = tf.reduce_min(points_x) - epsilon
  max_x = tf.reduce_max(points_x) + epsilon
  min_y = tf.reduce_min(points_y) - epsilon
  max_y = tf.reduce_max(points_y) + epsilon
  if x_random_crop_size is not None:
    min_x = center_x - x_random_crop_size / 2.0 - epsilon
    max_x = center_x + x_random_crop_size / 2.0 + epsilon
  if y_random_crop_size is not None:
    min_y = center_y - y_random_crop_size / 2.0 - epsilon
    max_y = center_y + y_random_crop_size / 2.0 + epsilon

  x_mask = tf.logical_and(tf.greater(points_x, min_x),
                          tf.less(points_x, max_x))
  y_mask = tf.logical_and(tf.greater(points_y, min_y),
                          tf.less(points_y, max_y))
  points_mask = tf.logical_and(x_mask, y_mask)

  for key in sorted(mesh_inputs):
    mesh_inputs[key] = tf.boolean_mask(mesh_inputs[key], points_mask)
  for key in sorted(view_indices_2d_inputs):
    view_indices_2d_inputs[key] = tf.transpose(
        tf.boolean_mask(
            tf.transpose(view_indices_2d_inputs[key], [1, 0, 2]), points_mask),
        [1, 0, 2])
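# Minimal, self-contained sketch (not part of the original module) of the
# crop-mask construction that randomly_crop_points relies on, using a raw
# point tensor instead of the standard_fields-keyed dictionaries. The helper
# name and default sizes below are illustrative only.
def _example_random_crop_mask(points, x_crop_size=2.0, y_crop_size=2.0,
                              epsilon=1e-5):
  """Returns a boolean mask keeping points inside a random x/y crop window."""
  num_points = tf.shape(points)[0]
  random_index = tf.random.uniform(
      [], minval=0, maxval=num_points, dtype=tf.int32)
  center_x = points[random_index, 0]
  center_y = points[random_index, 1]
  x_mask = tf.logical_and(
      tf.greater(points[:, 0], center_x - x_crop_size / 2.0 - epsilon),
      tf.less(points[:, 0], center_x + x_crop_size / 2.0 + epsilon))
  y_mask = tf.logical_and(
      tf.greater(points[:, 1], center_y - y_crop_size / 2.0 - epsilon),
      tf.less(points[:, 1], center_y + y_crop_size / 2.0 + epsilon))
  return tf.logical_and(x_mask, y_mask)

# Example usage (eager mode):
#   points = tf.random.uniform([1000, 3], minval=-10.0, maxval=10.0)
#   cropped = tf.boolean_mask(points, _example_random_crop_mask(points))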
def _prepare_lidar_points(inputs, lidar_names):
  """Integrates and returns the lidar points in vehicle coordinate frame."""
  points_position = []
  points_intensity = []
  points_elongation = []
  points_normal = []
  points_in_image_frame_xy = []
  points_in_image_frame_id = []
  for lidar_name in lidar_names:
    lidar_location = tf.reshape(
        inputs['lidars/%s/extrinsics/t' % lidar_name], [-1, 3])
    inside_no_label_zone = tf.reshape(
        inputs['lidars/%s/pointcloud/inside_nlz' % lidar_name], [-1])
    valid_points_mask = tf.math.logical_not(inside_no_label_zone)
    points_position_current_lidar = tf.boolean_mask(
        inputs['lidars/%s/pointcloud/positions' % lidar_name],
        valid_points_mask)
    points_position.append(points_position_current_lidar)
    points_intensity.append(
        tf.boolean_mask(
            inputs['lidars/%s/pointcloud/intensity' % lidar_name],
            valid_points_mask))
    points_elongation.append(
        tf.boolean_mask(
            inputs['lidars/%s/pointcloud/elongation' % lidar_name],
            valid_points_mask))
    points_to_lidar_vectors = lidar_location - points_position_current_lidar
    points_normal_direction = points_to_lidar_vectors / tf.expand_dims(
        tf.norm(points_to_lidar_vectors, axis=1), axis=1)
    points_normal.append(points_normal_direction)
    points_in_image_frame_xy.append(
        tf.boolean_mask(
            inputs['lidars/%s/camera_projections/positions' % lidar_name],
            valid_points_mask))
    points_in_image_frame_id.append(
        tf.boolean_mask(
            inputs['lidars/%s/camera_projections/ids' % lidar_name],
            valid_points_mask))
  points_position = tf.concat(points_position, axis=0)
  points_intensity = tf.concat(points_intensity, axis=0)
  points_elongation = tf.concat(points_elongation, axis=0)
  points_normal = tf.concat(points_normal, axis=0)
  points_in_image_frame_xy = tf.concat(points_in_image_frame_xy, axis=0)
  points_in_image_frame_id = tf.cast(
      tf.concat(points_in_image_frame_id, axis=0), dtype=tf.int32)
  points_in_image_frame_yx = tf.cast(
      tf.reverse(points_in_image_frame_xy, axis=[-1]), dtype=tf.int32)

  return (points_position, points_intensity, points_elongation, points_normal,
          points_in_image_frame_yx, points_in_image_frame_id)
def convert_to_simclr_episode(support_images=None,
                              support_labels=None,
                              support_class_ids=None,
                              query_images=None,
                              query_labels=None,
                              query_class_ids=None):
  """Convert a single episode into a SimCLR Episode."""

  # If there were k query examples of class c, keep the first k support
  # examples of class c as 'simclr' queries. We do this by assigning an
  # id for each image in the query set, implemented as label * 10000 + x + 1,
  # where x is the number of images of the same label with a lower index
  # within the query set. We do the same for the support set, which gives us
  # a mapping between query and support images which is injective (as long
  # as there are enough support-set images of each class).
  #
  # Note: assumes there are fewer than 10000 images of each class.
  query_idx_within_class = tf.cast(
      tf.equal(query_labels[tf.newaxis, :], query_labels[:, tf.newaxis]),
      tf.int32)
  query_idx_within_class = tf.linalg.diag_part(
      tf.cumsum(query_idx_within_class, axis=1))
  query_uid = query_labels * 10000 + query_idx_within_class
  support_idx_within_class = tf.cast(
      tf.equal(support_labels[tf.newaxis, :], support_labels[:, tf.newaxis]),
      tf.int32)
  support_idx_within_class = tf.linalg.diag_part(
      tf.cumsum(support_idx_within_class, axis=1))
  support_uid = support_labels * 10000 + support_idx_within_class

  # Compute which support-set images have matches in the query set, and
  # discard the rest to produce the new query set.
  support_keep = tf.reduce_any(
      tf.equal(support_uid[:, tf.newaxis], query_uid[tf.newaxis, :]), axis=1)
  query_images = tf.boolean_mask(support_images, support_keep)

  support_labels = tf.range(tf.shape(support_labels)[0],
                            dtype=support_labels.dtype)
  query_labels = tf.boolean_mask(support_labels, support_keep)
  query_class_ids = tf.boolean_mask(support_class_ids, support_keep)

  # Finally, apply SimCLR augmentation to all images.
  # Note SimCLR only blurs one image.
  query_images = simclr_augment(query_images, blur=True)
  support_images = simclr_augment(support_images)

  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)
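# Illustrative sketch (not in the original source): the cumsum-diagonal trick
# used above assigns each example its 1-based rank among examples that share
# its label, which is then packed into a uid as label * 10000 + rank.
def _example_query_uid(labels):
  """Returns label * 10000 + (1-based index within class) for toy labels."""
  within = tf.cast(
      tf.equal(labels[tf.newaxis, :], labels[:, tf.newaxis]), tf.int32)
  idx_within_class = tf.linalg.diag_part(tf.cumsum(within, axis=1))
  return labels * 10000 + idx_within_class

# _example_query_uid(tf.constant([0, 1, 0, 1, 1]))
#   -> [1, 10001, 2, 10002, 10003]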
def _remove_second_return_lidar_points(mesh_inputs, view_indices_2d_inputs):
  """Removes the points that are not lidar first-return."""
  if standard_fields.InputDataFields.point_spin_coordinates not in mesh_inputs:
    raise ValueError('spin_coordinates not in mesh_inputs.')
  first_return_mask = tf.equal(
      tf.cast(
          mesh_inputs[
              standard_fields.InputDataFields.point_spin_coordinates][:, 2],
          dtype=tf.int32), 0)
  for key in sorted(mesh_inputs):
    mesh_inputs[key] = tf.boolean_mask(mesh_inputs[key], first_return_mask)
  for key in sorted(view_indices_2d_inputs):
    view_indices_2d_inputs[key] = tf.transpose(
        tf.boolean_mask(
            tf.transpose(view_indices_2d_inputs[key], [1, 0, 2]),
            first_return_mask), [1, 0, 2])
def _filter_valid_objects(inputs):
  """Removes the objects that do not contain 3d info.

  Args:
    inputs: A dictionary containing input tensors.
  """
  if standard_fields.InputDataFields.objects_class not in inputs:
    return

  valid_objects_mask = tf.reshape(
      tf.greater(inputs[standard_fields.InputDataFields.objects_class], 0),
      [-1])
  if standard_fields.InputDataFields.objects_has_3d_info in inputs:
    objects_with_3d_info = tf.reshape(
        tf.cast(
            inputs[standard_fields.InputDataFields.objects_has_3d_info],
            dtype=tf.bool), [-1])
    valid_objects_mask = tf.logical_and(objects_with_3d_info,
                                        valid_objects_mask)
  if standard_fields.InputDataFields.objects_difficulty in inputs:
    valid_objects_mask = tf.logical_and(
        valid_objects_mask,
        tf.greater(
            tf.reshape(
                inputs[standard_fields.InputDataFields.objects_difficulty],
                [-1]), 0))
  for key in _OBJECT_KEYS:
    if key in inputs:
      inputs[key] = tf.boolean_mask(inputs[key], valid_objects_mask)
def experience_to_transitions(experience):
  boundary_mask = tf.logical_not(experience.is_boundary()[:, 0])
  experience = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, boundary_mask), experience)
  time_steps, policy_steps, next_time_steps = (
      trajectory.experience_to_transitions(experience, True))
  actions = policy_steps.action
  return time_steps, actions, next_time_steps
def embedding_regularization_loss(inputs,
                                  outputs,
                                  lambda_coef=0.0001,
                                  regularization_type='unit_length',
                                  is_intermediate=False):
  """Embedding regularization loss.

  Args:
    inputs: A dictionary that contains
      num_valid_voxels - A tf.int32 tensor of size [batch_size].
      instance_ids - A tf.int32 tensor of size [batch_size, n].
    outputs: A dictionary that contains
      embeddings - A tf.float32 tensor of size [batch_size, n, f].
    lambda_coef: Regularization loss coefficient.
    regularization_type: Regularization loss type. Supported values are 'msq'
      and 'unit_length'. 'msq' stands for 'mean square' which penalizes the
      embedding vectors if they have a length far from zero. 'unit_length'
      penalizes the embedding vectors if they have a length far from one.
    is_intermediate: True if applied to intermediate predictions; otherwise,
      False.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  instance_ids_key = standard_fields.InputDataFields.object_instance_id_voxels
  num_voxels_key = standard_fields.InputDataFields.num_valid_voxels
  if is_intermediate:
    embedding_key = (
        standard_fields.DetectionResultFields
        .intermediate_instance_embedding_voxels)
  else:
    embedding_key = (
        standard_fields.DetectionResultFields.instance_embedding_voxels)
  if instance_ids_key not in inputs:
    raise ValueError('instance_ids is missing in inputs.')
  if embedding_key not in outputs:
    raise ValueError('embedding is missing in outputs.')
  if num_voxels_key not in inputs:
    raise ValueError('num_voxels is missing in inputs.')
  batch_size = inputs[num_voxels_key].get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('batch_size is not defined at graph construction time.')
  num_valid_voxels = inputs[num_voxels_key]
  num_voxels = tf.shape(inputs[instance_ids_key])[1]
  valid_mask = tf.less(
      tf.tile(tf.expand_dims(tf.range(num_voxels), axis=0), [batch_size, 1]),
      tf.expand_dims(num_valid_voxels, axis=1))
  valid_mask = tf.reshape(valid_mask, [-1])
  embedding_dims = outputs[embedding_key].get_shape().as_list()[-1]
  if embedding_dims is None:
    raise ValueError(
        'Embedding dimension is unknown at graph construction time.')
  embedding = tf.reshape(outputs[embedding_key], [-1, embedding_dims])
  embedding = tf.boolean_mask(embedding, valid_mask)
  return metric_learning_losses.regularization_loss(
      embedding=embedding,
      lambda_coef=lambda_coef,
      regularization_type=regularization_type)
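# Minimal sketch (illustrative, not from the original module) of the padded
# batch masking used above: voxels beyond num_valid_voxels[i] in example i are
# dropped before the regularization loss is computed. The helper name is an
# assumption for illustration.
def _example_valid_voxel_mask(num_valid_voxels, num_voxels):
  """Returns a flat tf.bool mask of shape [batch_size * num_voxels]."""
  batch_size = num_valid_voxels.get_shape().as_list()[0]
  valid_mask = tf.less(
      tf.tile(tf.expand_dims(tf.range(num_voxels), axis=0), [batch_size, 1]),
      tf.expand_dims(num_valid_voxels, axis=1))
  return tf.reshape(valid_mask, [-1])

# _example_valid_voxel_mask(tf.constant([2, 3]), num_voxels=4)
#   -> [True, True, False, False, True, True, True, False]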
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
  """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for
      the user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
  scores, score_no_click = _get_unnormalized_scores(next_states)

  # Obtain all possible slates given current docs in the candidate set.
  slate_size = next_actions.get_shape().as_list()[1]
  num_candidates = next_q_values.get_shape().as_list()[1]
  mesh_args = [list(range(num_candidates))] * slate_size
  slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
  slates = tf.reshape(slates, shape=(-1, slate_size))
  # Filter slates that include duplicates to ensure each document is picked
  # at most once.
  unique_mask = tf.map_fn(
      lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
      slates,
      dtype=tf.bool)
  # [num_of_slates, slate_size]
  slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

  # [batch_size, num_of_slates, slate_size]
  next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
  # [batch_size, num_of_slates, slate_size]
  scores_slate = tf.gather(scores, slates, axis=1)
  # [batch_size, num_of_slates]
  batch_size = next_states.get_shape().as_list()[0]
  score_no_click_slate = tf.reshape(
      tf.tile(score_no_click,
              tf.shape(input=slates)[:1]), [batch_size, -1])

  # [batch_size, num_of_slates]
  next_q_target_slate = tf.reduce_sum(
      input_tensor=next_q_values_slate * scores_slate, axis=2) / (
          tf.reduce_sum(input_tensor=scores_slate, axis=2) +
          score_no_click_slate)

  next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

  return reward + gamma * next_q_target_max * (
      1. - tf.cast(terminals, tf.float32))
def loss_fn():
  """Loss function."""
  num_classes = logits.get_shape().as_list()[-1]
  if num_classes is None:
    raise ValueError('Number of classes is unknown.')
  masked_logits = tf.boolean_mask(logits, background_mask)
  masked_weights = tf.pow(
      1.0 - tf.reshape(tf.nn.softmax(masked_logits)[:, 0], [-1, 1]), gamma)
  num_points = tf.shape(masked_logits)[0]
  masked_weights = masked_weights * tf.cast(
      num_points, dtype=tf.float32) / tf.reduce_sum(masked_weights)
  masked_labels_one_hot = tf.one_hot(
      indices=tf.boolean_mask(labels, background_mask), depth=num_classes)
  loss = classification_loss_fn(
      logits=masked_logits,
      labels=masked_labels_one_hot,
      weights=masked_weights)
  return loss
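# Illustrative sketch (not from the original module) of the focal-style
# weighting used in loss_fn above: points the model already scores as
# background (class 0) with high confidence are down-weighted by
# (1 - p_background) ** gamma, and the weights are rescaled so they sum to the
# number of points. The helper name is an assumption for illustration.
def _example_focal_weights(logits, gamma=2.0):
  """Returns per-point weights of shape [num_points, 1] for toy logits."""
  p_background = tf.nn.softmax(logits)[:, 0]
  weights = tf.pow(1.0 - tf.reshape(p_background, [-1, 1]), gamma)
  num_points = tf.cast(tf.shape(logits)[0], tf.float32)
  return weights * num_points / tf.reduce_sum(weights)

# _example_focal_weights(tf.constant([[2.0, 0.0], [0.0, 2.0]]))
#   -> roughly [[0.04], [1.96]]; the confidently-background point contributes
#      little to the loss.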
def _body_fn(i, indices_range, indices):
  """Computes the indices of the i-th point feature in each segment."""
  indices_i = tf.math.unsorted_segment_max(
      data=indices_range, segment_ids=segment_ids, num_segments=num_segments)
  indices_i_positive_mask = tf.greater(indices_i, 0)
  indices_i_positive = tf.boolean_mask(indices_i, indices_i_positive_mask)
  boolean_mask = tf.scatter_nd(
      indices=tf.cast(
          tf.expand_dims(indices_i_positive - 1, axis=1), dtype=tf.int64),
      updates=tf.ones_like(indices_i_positive, dtype=tf.int32),
      shape=(n,))
  indices_range *= (1 - boolean_mask)
  indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
  indices_i = tf.pad(
      tf.expand_dims(indices_i, axis=1),
      paddings=[[0, 0], [i, num_samples_per_voxel - i - 1]])
  indices += indices_i
  i = i + 1
  return i, indices_range, indices
def select_slate_optimal(slate_size, s_no_click, s, q):
  """Selects the slate using exhaustive search.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """
  num_candidates = s.shape.as_list()[0]

  # Obtain all possible slates given current docs in the candidate set.
  mesh_args = [list(range(num_candidates))] * slate_size
  slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
  slates = tf.reshape(slates, shape=(-1, slate_size))

  # Filter slates that include duplicates to ensure each document is picked
  # at most once.
  unique_mask = tf.map_fn(
      lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
      slates,
      dtype=tf.bool)
  slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

  slate_q_values = tf.gather(s * q, slates)
  slate_scores = tf.gather(s, slates)
  slate_normalizer = tf.reduce_sum(
      input_tensor=slate_scores, axis=1) + s_no_click

  slate_q_values = slate_q_values / tf.expand_dims(slate_normalizer, 1)
  slate_sum_q_values = tf.reduce_sum(input_tensor=slate_q_values, axis=1)
  max_q_slate_index = tf.argmax(input=slate_sum_q_values)
  return tf.gather(slates, max_q_slate_index, axis=0)
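# Hedged usage sketch (toy values, not from the original code). With
# slate_size=1 the slate score for document d reduces to
# s[d] * q[d] / (s[d] + s_no_click), so document 0 below wins even though
# document 1 has a larger click score:
#   select_slate_optimal(
#       slate_size=1,
#       s_no_click=tf.constant(1.0),
#       s=tf.constant([1.0, 3.0]),
#       q=tf.constant([1.0, 0.5]))
#   -> [0]   # 1.0 * 1.0 / 2.0 = 0.5  vs.  3.0 * 0.5 / 4.0 = 0.375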
def _box_classification_loss_unbatched(inputs_1, outputs_1, is_intermediate,
                                        is_balanced, mine_hard_negatives,
                                        hard_negative_score_threshold):
  """Loss function for input and outputs of batch size 1."""
  valid_mask = _get_voxels_valid_mask(inputs_1=inputs_1)
  if is_intermediate:
    logits = outputs_1[standard_fields.DetectionResultFields.
                       intermediate_object_semantic_voxels]
  else:
    logits = outputs_1[
        standard_fields.DetectionResultFields.object_semantic_voxels]
  num_classes = logits.get_shape().as_list()[-1]
  if num_classes is None:
    raise ValueError('Number of classes is unknown.')
  logits = tf.boolean_mask(tf.reshape(logits, [-1, num_classes]), valid_mask)
  labels = tf.boolean_mask(
      tf.reshape(
          inputs_1[standard_fields.InputDataFields.object_class_voxels],
          [-1, 1]), valid_mask)
  if mine_hard_negatives or is_balanced:
    instances = tf.boolean_mask(
        tf.reshape(
            inputs_1[
                standard_fields.InputDataFields.object_instance_id_voxels],
            [-1]), valid_mask)
  params = {}
  if mine_hard_negatives:
    negative_scores = tf.reshape(tf.nn.softmax(logits)[:, 0], [-1])
    hard_negative_mask = tf.logical_and(
        tf.less(negative_scores, hard_negative_score_threshold),
        tf.equal(tf.reshape(labels, [-1]), 0))
    hard_negative_labels = tf.boolean_mask(labels, hard_negative_mask)
    hard_negative_logits = tf.boolean_mask(logits, hard_negative_mask)
    hard_negative_instances = tf.boolean_mask(
        tf.ones_like(instances) * (tf.reduce_max(instances) + 1),
        hard_negative_mask)
    logits = tf.concat([logits, hard_negative_logits], axis=0)
    instances = tf.concat([instances, hard_negative_instances], axis=0)
    labels = tf.concat([labels, hard_negative_labels], axis=0)
  if is_balanced:
    weights = loss_utils.get_balanced_loss_weights_multiclass(
        labels=tf.expand_dims(instances, axis=1))
    params['weights'] = weights
  return classification_loss_fn(logits=logits, labels=labels, **params)
def prepare_kitti_dataset(inputs, valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  prepared_inputs[standard_fields.InputDataFields.point_positions] = inputs[
      standard_fields.InputDataFields.point_positions]
  prepared_inputs[standard_fields.InputDataFields.point_intensities] = inputs[
      standard_fields.InputDataFields.point_intensities]
  prepared_inputs[standard_fields.InputDataFields
                  .camera_intrinsics] = inputs['cameras/cam02/intrinsics/K']
  prepared_inputs[
      standard_fields.InputDataFields
      .camera_rotation_matrix] = inputs['cameras/cam02/extrinsics/R']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_translation] = inputs['cameras/cam02/extrinsics/t']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_image] = inputs['cameras/cam02/image']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_raw_image] = inputs['cameras/cam02/image']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_original_image] = inputs['cameras/cam02/image']
  if 'scene_name' in inputs and 'frame_name' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.camera_image_name] = tf.strings.join(
            [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_length] = tf.reshape(
            inputs['objects/shape/dimension'][:, 0], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_width] = tf.reshape(
            inputs['objects/shape/dimension'][:, 1], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_height] = tf.reshape(
            inputs['objects/shape/dimension'][:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_class] = tf.reshape(
            inputs['objects/category/label'], [-1, 1])
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[standard_fields.InputDataFields.objects_class],
            dtype=tf.int32), dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[standard_fields.InputDataFields.objects_class],
              object_class))
    valid_objects_mask = tf.reshape(valid_objects_mask, [-1])
    for key in standard_fields.get_input_object_fields():
      if key in prepared_inputs:
        prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                               valid_objects_mask)
  return prepared_inputs
def mask_tensor(x, s):
  """Splits `x` into the elements selected by mask `s` and the rest."""
  not_x = tf.boolean_mask(x, tf.logical_not(s))
  x = tf.boolean_mask(x, s)
  return x, not_x
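# Illustrative usage (not in the original source):
#   kept, dropped = mask_tensor(tf.constant([1, 2, 3, 4]),
#                               tf.constant([True, False, True, False]))
#   kept    -> [1, 3]
#   dropped -> [2, 4]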
def prepare_waymo_open_dataset(inputs,
                               valid_object_classes=None,
                               max_object_distance_from_source=74.88):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: List of valid object classes. If None, it is ignored.
    max_object_distance_from_source: Maximum distance of objects from source.
      It will be ignored if None.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if standard_fields.InputDataFields.point_positions in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_positions] = inputs[
        standard_fields.InputDataFields.point_positions]
  if standard_fields.InputDataFields.point_intensities in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.point_intensities] = inputs[
            standard_fields.InputDataFields.point_intensities]
  if standard_fields.InputDataFields.point_elongations in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.point_elongations] = inputs[
            standard_fields.InputDataFields.point_elongations]
  if standard_fields.InputDataFields.point_normals in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_normals] = inputs[
        standard_fields.InputDataFields.point_normals]
  if 'cameras/front/intrinsics/K' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_intrinsics] = inputs['cameras/front/intrinsics/K']
  if 'cameras/front/extrinsics/R' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields
        .camera_rotation_matrix] = inputs['cameras/front/extrinsics/R']
  if 'cameras/front/extrinsics/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_translation] = inputs['cameras/front/extrinsics/t']
  if 'cameras/front/image' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image] = inputs['cameras/front/image']
    prepared_inputs[standard_fields.InputDataFields
                    .camera_raw_image] = inputs['cameras/front/image']
    prepared_inputs[standard_fields.InputDataFields
                    .camera_original_image] = inputs['cameras/front/image']
  if 'scene_name' in inputs and 'frame_name' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.camera_image_name] = tf.strings.join(
            [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_length] = tf.reshape(
            inputs['objects/shape/dimension'][:, 0], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_width] = tf.reshape(
            inputs['objects/shape/dimension'][:, 1], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_height] = tf.reshape(
            inputs['objects/shape/dimension'][:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_class] = tf.reshape(
            inputs['objects/category/label'], [-1, 1])
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[standard_fields.InputDataFields.objects_class],
            dtype=tf.int32), dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[standard_fields.InputDataFields.objects_class],
              object_class))
    valid_objects_mask = tf.reshape(valid_objects_mask, [-1])
    for key in standard_fields.get_input_object_fields():
      if key in prepared_inputs:
        prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                               valid_objects_mask)
  if max_object_distance_from_source is not None:
    if standard_fields.InputDataFields.objects_center in prepared_inputs:
      object_distances = tf.norm(
          prepared_inputs[
              standard_fields.InputDataFields.objects_center][:, 0:2],
          axis=1)
      valid_mask = tf.less(object_distances, max_object_distance_from_source)
      for key in standard_fields.get_input_object_fields():
        if key in prepared_inputs:
          prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                                 valid_mask)
  return prepared_inputs
def prepare_scannet_frame_dataset(inputs,
                                  min_pixel_depth=0.3,
                                  max_pixel_depth=6.0,
                                  valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    min_pixel_depth: Pixels with depth values less than this are pruned.
    max_pixel_depth: Pixels with depth values more than this are pruned.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if 'cameras/rgbd_camera/intrinsics/K' not in inputs:
    raise ValueError('Intrinsic matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/R' not in inputs:
    raise ValueError('Extrinsic rotation matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/t' not in inputs:
    raise ValueError('Extrinsics translation is missing.')
  if 'cameras/rgbd_camera/depth_image' not in inputs:
    raise ValueError('Depth image is missing.')
  if 'cameras/rgbd_camera/color_image' not in inputs:
    raise ValueError('Color image is missing.')
  if 'frame_name' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image_name] = inputs['frame_name']
  camera_intrinsics = inputs['cameras/rgbd_camera/intrinsics/K']
  depth_image = inputs['cameras/rgbd_camera/depth_image']
  image_height = tf.shape(depth_image)[0]
  image_width = tf.shape(depth_image)[1]
  x, y = tf.meshgrid(
      tf.range(image_width), tf.range(image_height), indexing='xy')
  x = tf.reshape(tf.cast(x, dtype=tf.float32) + 0.5, [-1, 1])
  y = tf.reshape(tf.cast(y, dtype=tf.float32) + 0.5, [-1, 1])
  point_positions = projections.image_frame_to_camera_frame(
      image_frame=tf.concat([x, y], axis=1),
      camera_intrinsics=camera_intrinsics)
  rotate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/R']
  translate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/t']
  point_positions = projections.to_world_frame(
      camera_frame_points=point_positions,
      rotate_world_to_camera=rotate_world_to_camera,
      translate_world_to_camera=translate_world_to_camera)
  prepared_inputs[standard_fields.InputDataFields
                  .point_positions] = point_positions * tf.reshape(
                      depth_image, [-1, 1])
  depth_values = tf.reshape(depth_image, [-1])
  valid_depth_mask = tf.logical_and(
      tf.greater_equal(depth_values, min_pixel_depth),
      tf.less_equal(depth_values, max_pixel_depth))
  prepared_inputs[standard_fields.InputDataFields.point_colors] = tf.reshape(
      tf.cast(inputs['cameras/rgbd_camera/color_image'], dtype=tf.float32),
      [-1, 3])
  prepared_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 / 255.0)
  prepared_inputs[standard_fields.InputDataFields.point_colors] -= 1.0
  prepared_inputs[
      standard_fields.InputDataFields.point_positions] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_positions],
          valid_depth_mask)
  prepared_inputs[
      standard_fields.InputDataFields.point_colors] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_colors],
          valid_depth_mask)
  if 'cameras/rgbd_camera/semantic_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/semantic_image'], [-1, 1]),
            dtype=tf.int32)
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.boolean_mask(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            valid_depth_mask)
  if 'cameras/rgbd_camera/instance_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_instance_id_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/instance_image'], [-1]),
            dtype=tf.int32)
    prepared_inputs[standard_fields.InputDataFields
                    .object_instance_id_points] = tf.boolean_mask(
                        prepared_inputs[standard_fields.InputDataFields
                                        .object_instance_id_points],
                        valid_depth_mask)
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            dtype=tf.int32), dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[
                  standard_fields.InputDataFields.object_class_points],
              object_class))
    valid_objects_mask = tf.cast(
        valid_objects_mask,
        dtype=prepared_inputs[
            standard_fields.InputDataFields.object_class_points].dtype)
    prepared_inputs[standard_fields.InputDataFields
                    .object_class_points] *= valid_objects_mask
  return prepared_inputs
def classification_loss_using_mask_iou_func(embeddings,
                                            logits,
                                            instance_ids,
                                            class_labels,
                                            num_samples,
                                            valid_mask=None,
                                            max_instance_id=None,
                                            similarity_strategy='dotproduct',
                                            is_balanced=True):
  """Classification loss using mask iou.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    logits: A tf.float32 tensor of size [batch_size, n, num_classes]. It is
      assumed that background is class 0.
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    class_labels: A tf.int32 tensor of size [batch_size, n]. It is assumed
      that the background voxels are assigned to class 0.
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the
      value is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
      embedding vectors. Possible values are 'dotproduct' and 'distance'.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for foreground vs. background voxels.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)
  class_labels = tf.reshape(class_labels, [batch_size, -1, 1])
  sampled_embeddings, sampled_instance_ids, sampled_indices = (
      sampling_utils.balanced_sample(
          features=embeddings,
          instance_ids=instance_ids,
          num_samples=num_samples,
          valid_mask=valid_mask,
          max_instance_id=max_instance_id))
  losses = []
  for i in range(batch_size):
    embeddings_i = embeddings[i, :, :]
    instance_ids_i = instance_ids[i, :]
    class_labels_i = class_labels[i, :, :]
    logits_i = logits[i, :]
    sampled_embeddings_i = sampled_embeddings[i, :, :]
    sampled_instance_ids_i = sampled_instance_ids[i, :]
    sampled_indices_i = sampled_indices[i, :]
    sampled_class_labels_i = tf.gather(class_labels_i, sampled_indices_i)
    sampled_logits_i = tf.gather(logits_i, sampled_indices_i)
    if valid_mask is not None:
      valid_mask_i = valid_mask[i]
      embeddings_i = tf.boolean_mask(embeddings_i, valid_mask_i)
      instance_ids_i = tf.boolean_mask(instance_ids_i, valid_mask_i)
    loss_i = classification_loss_using_mask_iou_func_unbatched(
        embeddings=embeddings_i,
        instance_ids=instance_ids_i,
        sampled_embeddings=sampled_embeddings_i,
        sampled_instance_ids=sampled_instance_ids_i,
        sampled_class_labels=sampled_class_labels_i,
        sampled_logits=sampled_logits_i,
        similarity_strategy=similarity_strategy,
        is_balanced=is_balanced)
    losses.append(loss_i)
  return tf.math.reduce_mean(tf.stack(losses))
def update_state(self, inputs, outputs):
  """Function that updates the metric state at each example.

  Args:
    inputs: A dictionary containing input tensors.
    outputs: A dictionary containing output tensors.

  Returns:
    Update op.
  """
  detections_score = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_score], [-1])
  detections_class = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_class], [-1])
  detections_length = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_length], [-1])
  detections_height = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_height], [-1])
  detections_width = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_width], [-1])
  detections_center = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_center], [-1, 3])
  detections_rotation_matrix = tf.reshape(
      outputs[standard_fields.DetectionResultFields.objects_rotation_matrix],
      [-1, 3, 3])
  gt_class = tf.reshape(
      inputs[standard_fields.InputDataFields.objects_class], [-1])
  gt_length = tf.reshape(
      inputs[standard_fields.InputDataFields.objects_length], [-1])
  gt_height = tf.reshape(
      inputs[standard_fields.InputDataFields.objects_height], [-1])
  gt_width = tf.reshape(
      inputs[standard_fields.InputDataFields.objects_width], [-1])
  gt_center = tf.reshape(
      inputs[standard_fields.InputDataFields.objects_center], [-1, 3])
  gt_rotation_matrix = tf.reshape(
      inputs[standard_fields.InputDataFields.objects_rotation_matrix],
      [-1, 3, 3])
  for c in self.class_range:
    gt_mask_c = tf.equal(gt_class, c)
    num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
    gt_length_c = tf.boolean_mask(gt_length, gt_mask_c)
    gt_height_c = tf.boolean_mask(gt_height, gt_mask_c)
    gt_width_c = tf.boolean_mask(gt_width, gt_mask_c)
    gt_center_c = tf.boolean_mask(gt_center, gt_mask_c)
    gt_rotation_matrix_c = tf.boolean_mask(gt_rotation_matrix, gt_mask_c)
    detections_mask_c = tf.equal(detections_class, c)
    num_detections_c = tf.math.reduce_sum(
        tf.cast(detections_mask_c, dtype=tf.int32))
    if num_detections_c == 0:
      continue
    det_length_c = tf.boolean_mask(detections_length, detections_mask_c)
    det_height_c = tf.boolean_mask(detections_height, detections_mask_c)
    det_width_c = tf.boolean_mask(detections_width, detections_mask_c)
    det_center_c = tf.boolean_mask(detections_center, detections_mask_c)
    det_rotation_matrix_c = tf.boolean_mask(detections_rotation_matrix,
                                            detections_mask_c)
    det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
    det_scores_c, sorted_indices = tf.math.top_k(
        det_scores_c, k=num_detections_c)
    det_length_c = tf.gather(det_length_c, sorted_indices)
    det_height_c = tf.gather(det_height_c, sorted_indices)
    det_width_c = tf.gather(det_width_c, sorted_indices)
    det_center_c = tf.gather(det_center_c, sorted_indices)
    det_rotation_matrix_c = tf.gather(det_rotation_matrix_c, sorted_indices)
    tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
    if num_gt_c > 0:
      ious_c = box_ops.iou3d(
          boxes1_length=gt_length_c,
          boxes1_height=gt_height_c,
          boxes1_width=gt_width_c,
          boxes1_center=gt_center_c,
          boxes1_rotation_matrix=gt_rotation_matrix_c,
          boxes2_length=det_length_c,
          boxes2_height=det_height_c,
          boxes2_width=det_width_c,
          boxes2_center=det_center_c,
          boxes2_rotation_matrix=det_rotation_matrix_c)
      max_overlap_gt_ids = tf.cast(
          tf.math.argmax(ious_c, axis=0), dtype=tf.int32)
      is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
      for i in tf.range(num_detections_c):
        gt_id = max_overlap_gt_ids[i]
        if (ious_c[gt_id, i] > self.iou_threshold and
            is_gt_box_detected[gt_id] == 0):
          tp_c = tf.maximum(
              tf.one_hot(i, num_detections_c, dtype=tf.int32), tp_c)
          is_gt_box_detected = tf.maximum(
              tf.one_hot(gt_id, num_gt_c, dtype=tf.int32), is_gt_box_detected)
    self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
    self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
    self.num_gt[c] += num_gt_c
  return tf.no_op()
def _non_nan_mean(tensor_list):
  """Calculates the mean of a list of tensors while ignoring nans."""
  tensor = tf.stack(tensor_list)
  not_nan = tf.logical_not(tf.math.is_nan(tensor))
  return tf.reduce_mean(tf.boolean_mask(tensor, not_nan))
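# Illustrative usage (not in the original source): nan entries are ignored.
#   _non_nan_mean([tf.constant(1.0), tf.constant(float('nan')),
#                  tf.constant(3.0)])
#   -> 2.0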
def train_eval(
    load_root_dir,
    env_load_fn=None,
    gym_env_wrappers=[],
    monitor=False,
    env_name=None,
    agent_class=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    safety_critic_joint_fc_layers=None,
    safety_critic_lr=3e-4,
    safety_critic_bias_init_val=None,
    safety_critic_kernel_scale=None,
    n_envs=None,
    target_safety=0.2,
    fail_weight=None,
    # Params for train
    num_global_steps=10000,
    batch_size=256,
    # Params for eval
    run_eval=False,
    eval_metrics=[],
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    debug_summaries=False,
    seed=None):
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
        agent_class)
    agent_class = ALGOS.get(agent_class)

  train_ckpt_dir = osp.join(load_root_dir, 'train')
  rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

  py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
  tf_env = tf_py_environment.TFPyEnvironment(py_env)

  if monitor:
    vid_path = os.path.join(load_root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  if run_eval:
    eval_dir = os.path.join(load_root_dir, 'eval')
    n_envs = n_envs or num_eval_episodes
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                       buffer_size=num_eval_episodes,
                                       batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(prefix='EvalMetrics',
                                              buffer_size=num_eval_episodes,
                                              batch_size=n_envs)
    ] + [
        tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
        for m in eval_metrics
    ]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment([
            lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
        ] * n_envs))
    if seed:
      seeds = [seed * n_envs + i for i in range(n_envs)]
      try:
        eval_tf_env.pyenv.seed(seeds)
      except:
        pass

  global_step = tf.compat.v1.train.get_or_create_global_step()

  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=agents.normal_projection_net)

  critic_net = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=critic_joint_fc_layers)

  if agent_class in SAFETY_AGENTS:
    safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
    tf_agent = agent_class(time_step_spec,
                           action_spec,
                           actor_network=actor_net,
                           critic_network=critic_net,
                           safety_critic_network=safety_critic_net,
                           train_step_counter=global_step,
                           debug_summaries=False)
  else:
    tf_agent = agent_class(time_step_spec,
                           action_spec,
                           actor_network=actor_net,
                           critic_network=critic_net,
                           train_step_counter=global_step,
                           debug_summaries=False)

  collect_data_spec = tf_agent.collect_data_spec
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec, batch_size=1, max_length=1000000)
  replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

  tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
  if agent_class in SAFETY_AGENTS:
    target_safety = target_safety or tf_agent._target_safety
  loaded_train_steps = global_step.numpy()
  logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
               loaded_train_steps)
  global_step.assign(0)
  tf.summary.experimental.set_step(global_step)

  thresholds = [target_safety, 0.5]
  sc_metrics = [
      tf.keras.metrics.AUC(name='safety_critic_auc'),
      tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                      threshold=0.5),
      tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                     thresholds=thresholds),
      tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                      thresholds=thresholds),
      tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                     thresholds=thresholds),
      tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                      thresholds=thresholds)
  ]

  if seed:
    tf.compat.v1.set_random_seed(seed)

  summaries_flush_secs = 10
  timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
  offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
  config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                           summarize_config=True)
  tf.function(config_saver.after_create_session)()

  sc_summary_writer = tf.compat.v2.summary.create_file_writer(
      offline_train_dir, flush_millis=summaries_flush_secs * 1000)
  sc_summary_writer.set_as_default()

  if safety_critic_kernel_scale is not None:
    ki = tf.compat.v1.variance_scaling_initializer(
        scale=safety_critic_kernel_scale,
        mode='fan_in',
        distribution='truncated_normal')
  else:
    ki = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1. / 3., mode='fan_in', distribution='uniform')

  if safety_critic_bias_init_val is not None:
    bi = tf.constant_initializer(safety_critic_bias_init_val)
  else:
    bi = None
  sc_net_off = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=safety_critic_joint_fc_layers,
      kernel_initializer=ki,
      value_bias_initializer=bi,
      name='SafetyCriticOffline')
  sc_net_off.create_variables()
  target_sc_net_off = common.maybe_copy_target_network_with_checks(
      sc_net_off, None, 'TargetSafetyCriticNetwork')
  optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
  sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
  sc_checkpointer = common.Checkpointer(
      ckpt_dir=sc_net_off_ckpt_dir,
      safety_critic=sc_net_off,
      target_safety_critic=target_sc_net_off,
      optimizer=optimizer,
      global_step=global_step,
      max_to_keep=5)
  sc_checkpointer.initialize_or_restore()

  resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
  eval_policy = agents.SafeActorPolicyRSVar(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_net,
      safety_critic_network=sc_net_off,
      safety_threshold=target_safety,
      resample_counter=resample_counter,
      training=True)

  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3,
      num_steps=2,
      sample_batch_size=batch_size // 2).prefetch(3)
  data = iter(dataset)
  full_data = replay_buffer.gather_all()

  fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
  fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
  init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
  before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
  after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
  before_fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
  after_init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

  filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
  filter_mask = tf.pad(
      filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
  n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

  failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=n_failures,
      dataset_window_shift=1)
  data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

  sc_dataset_neg = failure_buffer.as_dataset(
      num_parallel_calls=3,
      sample_batch_size=batch_size // 2,
      num_steps=2).prefetch(3)
  neg_data = iter(sc_dataset_neg)

  get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
  eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                              after_init_step, get_action)

  losses = []
  mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
  target_update = train_utils.get_target_updater(sc_net_off,
                                                 target_sc_net_off)

  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    while global_step.numpy() < num_global_steps:
      pos_experience, _ = next(data)
      neg_experience, _ = next(neg_data)
      exp = data_utils.concat_batches(pos_experience, neg_experience,
                                      collect_data_spec)
      boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
      exp = nest_utils.fast_map_structure(
          lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
      safe_rew = exp.observation['task_agn_rew'][:, 1]
      if fail_weight:
        weights = tf.where(tf.cast(safe_rew, tf.bool),
                           fail_weight / 0.5, (1 - fail_weight) / 0.5)
      else:
        weights = None
      train_loss, sc_loss, lam_loss = train_step(
          exp,
          safe_rew,
          tf_agent,
          sc_net=sc_net_off,
          target_sc_net=target_sc_net_off,
          metrics=sc_metrics,
          weights=weights,
          target_safety=target_safety,
          optimizer=optimizer,
          target_update=target_update,
          debug_summaries=debug_summaries)
      global_step.assign_add(1)
      global_step_val = global_step.numpy()
      losses.append((train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
      mean_loss(train_loss)
      with tf.name_scope('Losses'):
        tf.compat.v2.summary.scalar(name='sc_loss', data=sc_loss,
                                    step=global_step_val)
        tf.compat.v2.summary.scalar(name='lam_loss', data=lam_loss,
                                    step=global_step_val)
        if global_step_val % summary_interval == 0:
          tf.compat.v2.summary.scalar(name=mean_loss.name,
                                      data=mean_loss.result(),
                                      step=global_step_val)
      if global_step_val % summary_interval == 0:
        with tf.name_scope('Metrics'):
          for metric in sc_metrics:
            if len(tf.squeeze(metric.result()).shape) == 0:
              tf.compat.v2.summary.scalar(name=metric.name,
                                          data=metric.result(),
                                          step=global_step_val)
            else:
              fmt_str = '_{}'.format(thresholds[0])
              tf.compat.v2.summary.scalar(name=metric.name + fmt_str,
                                          data=metric.result()[0],
                                          step=global_step_val)
              fmt_str = '_{}'.format(thresholds[1])
              tf.compat.v2.summary.scalar(name=metric.name + fmt_str,
                                          data=metric.result()[1],
                                          step=global_step_val)
            metric.reset_states()
      if global_step_val % eval_interval == 0:
        eval_sc(sc_net_off, step=global_step_val)
        if run_eval:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix='EvalMetrics',
          )
          if train_metrics_callback is not None:
            train_metrics_callback(results, global_step_val)
          metric_utils.log_metrics(eval_metrics)
          with eval_summary_writer.as_default():
            for eval_metric in eval_metrics[2:]:
              eval_metric.tf_summaries(train_step=global_step,
                                       step_metrics=eval_metrics[:2])
      if monitor and global_step_val % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step,
                                              monitor_policy_state)
          action, monitor_policy_state = (monitor_action.action,
                                          monitor_action.state)
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug(
            'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
            global_step_val, ep_len, time.time() - monitor_start)

      if global_step_val % train_checkpoint_interval == 0:
        sc_checkpointer.save(global_step=global_step_val)