def fn():
    for field in standard_fields.get_input_object_fields():
      if field in inputs:
        inputs[field] = tf.boolean_mask(inputs[field], valid_mask)
   for field in standard_fields.get_output_object_fields():
     if field in outputs:
       outputs[field] = tf.boolean_mask(outputs[field], valid_mask)
   return _box_corner_distance_loss(
       loss_type=loss_type,
       is_balanced=is_balanced,
       input_boxes_length=inputs[
           standard_fields.InputDataFields.objects_length],
       input_boxes_height=inputs[
           standard_fields.InputDataFields.objects_height],
       input_boxes_width=inputs[standard_fields.InputDataFields.objects_width],
       input_boxes_center=inputs[
           standard_fields.InputDataFields.objects_center],
       input_boxes_rotation_matrix=inputs[
           standard_fields.InputDataFields.objects_rotation_matrix],
       input_boxes_instance_id=inputs[
           standard_fields.InputDataFields.objects_instance_id],
       output_boxes_length=outputs[
           standard_fields.DetectionResultFields.objects_length],
       output_boxes_height=outputs[
           standard_fields.DetectionResultFields.objects_height],
       output_boxes_width=outputs[
           standard_fields.DetectionResultFields.objects_width],
       output_boxes_center=outputs[
           standard_fields.DetectionResultFields.objects_center],
       output_boxes_rotation_matrix=outputs[
           standard_fields.DetectionResultFields.objects_rotation_matrix],
       delta=delta)
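
The snippet above is a nested helper that closes over inputs, outputs, valid_mask and the loss parameters. A minimal standalone sketch of the shared-mask pattern it uses, with hypothetical field names in place of standard_fields:

import tensorflow as tf

# Minimal sketch (hypothetical field names): apply one validity mask to every
# tensor in parallel dictionaries, as the closure above does for inputs/outputs.
inputs = {
    'objects_length': tf.constant([[2.0], [3.0], [4.0]]),
    'objects_center': tf.constant([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]]),
}
valid_mask = tf.constant([True, False, True])
for field in inputs:
  inputs[field] = tf.boolean_mask(inputs[field], valid_mask)
# inputs['objects_length'] -> [[2.0], [4.0]]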
Example #2
  def update_state(self, inputs, outputs):
    """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
    detections_score = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_score], [-1])
    detections_class = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_class], [-1])
    num_detections = tf.shape(detections_score)[0]
    detections_instance_mask = tf.reshape(
        outputs[
            standard_fields.DetectionResultFields.instance_segments_voxel_mask],
        [num_detections, -1])
    gt_class = tf.reshape(inputs[standard_fields.InputDataFields.objects_class],
                          [-1])
    num_gt = tf.shape(gt_class)[0]
    gt_voxel_instance_ids = tf.reshape(
        inputs[standard_fields.InputDataFields.object_instance_id_voxels], [-1])
    gt_instance_masks = tf.transpose(
        tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
    for c in self.class_range:
      gt_mask_c = tf.equal(gt_class, c)
      num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
      gt_instance_masks_c = tf.boolean_mask(gt_instance_masks, gt_mask_c)
      detections_mask_c = tf.equal(detections_class, c)
      num_detections_c = tf.math.reduce_sum(
          tf.cast(detections_mask_c, dtype=tf.int32))
      if num_detections_c == 0:
        continue
      det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
      det_instance_mask_c = tf.boolean_mask(detections_instance_mask,
                                            detections_mask_c)
      det_scores_c, sorted_indices = tf.math.top_k(
          det_scores_c, k=num_detections_c)
      det_instance_mask_c = tf.gather(det_instance_mask_c, sorted_indices)
      tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
      if num_gt_c > 0:
        ious_c = instance_segmentation_utils.points_mask_iou(
            masks1=gt_instance_masks_c, masks2=det_instance_mask_c)
        max_overlap_gt_ids = tf.cast(
            tf.math.argmax(ious_c, axis=0), dtype=tf.int32)
        is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
        for i in tf.range(num_detections_c):
          gt_id = max_overlap_gt_ids[i]
          if (ious_c[gt_id, i] > self.iou_threshold and
              is_gt_box_detected[gt_id] == 0):
            tp_c = tf.maximum(
                tf.one_hot(i, num_detections_c, dtype=tf.int32), tp_c)
            is_gt_box_detected = tf.maximum(
                tf.one_hot(gt_id, num_gt_c, dtype=tf.int32), is_gt_box_detected)
      self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
      self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
      self.num_gt[c] += num_gt_c
    return tf.no_op()
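
A quick sketch of how the ground-truth instance masks above are built: subtracting 1 maps background (instance id 0) to index -1, which tf.one_hot turns into an all-zero row, so background voxels belong to no instance mask.

import tensorflow as tf

gt_voxel_instance_ids = tf.constant([1, 1, 2, 0, 2])
num_gt = 2
gt_instance_masks = tf.transpose(
    tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
# gt_instance_masks -> [[1., 1., 0., 0., 0.],
#                       [0., 0., 1., 0., 1.]]  (one row per ground-truth instance)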
Example #3
def randomly_crop_points(mesh_inputs,
                         view_indices_2d_inputs,
                         x_random_crop_size,
                         y_random_crop_size,
                         epsilon=1e-5):
  """Randomly crops points.

  Args:
    mesh_inputs: A dictionary containing input mesh (point) tensors.
    view_indices_2d_inputs: A dictionary containing input point to view
      correspondence tensors.
    x_random_crop_size: Size of the random crop in x dimension. If None, random
      crop will not take place on x dimension.
    y_random_crop_size: Size of the random crop in y dimension. If None, random
      crop will not take place on y dimension.
    epsilon: Epsilon (a very small value) used to add as a small margin to
      thresholds.
  """
  if x_random_crop_size is None and y_random_crop_size is None:
    return

  points = mesh_inputs[standard_fields.InputDataFields.point_positions]
  num_points = tf.shape(points)[0]
  # Pick a random point
  if x_random_crop_size is not None or y_random_crop_size is not None:
    random_index = tf.random.uniform([],
                                     minval=0,
                                     maxval=num_points,
                                     dtype=tf.int32)
    center_x = points[random_index, 0]
    center_y = points[random_index, 1]

  points_x = points[:, 0]
  points_y = points[:, 1]
  min_x = tf.reduce_min(points_x) - epsilon
  max_x = tf.reduce_max(points_x) + epsilon
  min_y = tf.reduce_min(points_y) - epsilon
  max_y = tf.reduce_max(points_y) + epsilon

  if x_random_crop_size is not None:
    min_x = center_x - x_random_crop_size / 2.0 - epsilon
    max_x = center_x + x_random_crop_size / 2.0 + epsilon

  if y_random_crop_size is not None:
    min_y = center_y - y_random_crop_size / 2.0 - epsilon
    max_y = center_y + y_random_crop_size / 2.0 + epsilon

  x_mask = tf.logical_and(tf.greater(points_x, min_x), tf.less(points_x, max_x))
  y_mask = tf.logical_and(tf.greater(points_y, min_y), tf.less(points_y, max_y))
  points_mask = tf.logical_and(x_mask, y_mask)

  for key in sorted(mesh_inputs):
    mesh_inputs[key] = tf.boolean_mask(mesh_inputs[key], points_mask)

  for key in sorted(view_indices_2d_inputs):
    view_indices_2d_inputs[key] = tf.transpose(
        tf.boolean_mask(
            tf.transpose(view_indices_2d_inputs[key], [1, 0, 2]), points_mask),
        [1, 0, 2])
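
The view-correspondence tensors above are masked along their point axis (axis 1) with a transpose / boolean_mask / transpose round trip; a small self-contained sketch:

import tensorflow as tf

view_indices = tf.reshape(tf.range(2 * 4 * 2), [2, 4, 2])  # [num_views, num_points, 2]
points_mask = tf.constant([True, False, True, False])
cropped = tf.transpose(
    tf.boolean_mask(tf.transpose(view_indices, [1, 0, 2]), points_mask),
    [1, 0, 2])
# cropped.shape -> [2, 2, 2]; only the points kept by the mask remain.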
def _prepare_lidar_points(inputs, lidar_names):
    """Integrates and returns the lidar points in vehicle coordinate frame."""
    points_position = []
    points_intensity = []
    points_elongation = []
    points_normal = []
    points_in_image_frame_xy = []
    points_in_image_frame_id = []
    for lidar_name in lidar_names:
        lidar_location = tf.reshape(
            inputs[('lidars/%s/extrinsics/t') % lidar_name], [-1, 3])
        inside_no_label_zone = tf.reshape(
            inputs[('lidars/%s/pointcloud/inside_nlz' % lidar_name)], [-1])
        valid_points_mask = tf.math.logical_not(inside_no_label_zone)
        points_position_current_lidar = tf.boolean_mask(
            inputs[('lidars/%s/pointcloud/positions' % lidar_name)],
            valid_points_mask)
        points_position.append(points_position_current_lidar)
        points_intensity.append(
            tf.boolean_mask(
                inputs[('lidars/%s/pointcloud/intensity' % lidar_name)],
                valid_points_mask))
        points_elongation.append(
            tf.boolean_mask(
                inputs[('lidars/%s/pointcloud/elongation' % lidar_name)],
                valid_points_mask))
        points_to_lidar_vectors = lidar_location - points_position_current_lidar
        points_normal_direction = points_to_lidar_vectors / tf.expand_dims(
            tf.norm(points_to_lidar_vectors, axis=1), axis=1)
        points_normal.append(points_normal_direction)
        points_in_image_frame_xy.append(
            tf.boolean_mask(
                inputs['lidars/%s/camera_projections/positions' % lidar_name],
                valid_points_mask))
        points_in_image_frame_id.append(
            tf.boolean_mask(
                inputs['lidars/%s/camera_projections/ids' % lidar_name],
                valid_points_mask))
    points_position = tf.concat(points_position, axis=0)
    points_intensity = tf.concat(points_intensity, axis=0)
    points_elongation = tf.concat(points_elongation, axis=0)
    points_normal = tf.concat(points_normal, axis=0)
    points_in_image_frame_xy = tf.concat(points_in_image_frame_xy, axis=0)
    points_in_image_frame_id = tf.cast(tf.concat(points_in_image_frame_id,
                                                 axis=0),
                                       dtype=tf.int32)
    points_in_image_frame_yx = tf.cast(tf.reverse(points_in_image_frame_xy,
                                                  axis=[-1]),
                                       dtype=tf.int32)

    return (points_position, points_intensity, points_elongation,
            points_normal, points_in_image_frame_yx, points_in_image_frame_id)
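
The normal direction above is simply the unit vector from each point back toward the lidar origin; a minimal sketch with made-up coordinates:

import tensorflow as tf

lidar_location = tf.constant([[0.0, 0.0, 2.0]])
points = tf.constant([[1.0, 0.0, 0.0], [0.0, 3.0, 2.0]])
points_to_lidar = lidar_location - points
normals = points_to_lidar / tf.expand_dims(tf.norm(points_to_lidar, axis=1), axis=1)
# Each row of `normals` has unit length and points from the point toward the sensor.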
Example #5
    def convert_to_simclr_episode(support_images=None,
                                  support_labels=None,
                                  support_class_ids=None,
                                  query_images=None,
                                  query_labels=None,
                                  query_class_ids=None):
        """Convert a single episode into a SimCLR Episode."""

        # If there were k query examples of class c, keep the first k support
        # examples of class c as 'simclr' queries.  We do this by assigning an
        # id for each image in the query set, implemented as label*10000+x+1, where
        # x is the number of images of the same label with a lower index within
        # the query set.  We do the same for the support set, which gives us a
        # mapping between query and support images which is injective (as long
        # as there are enough support-set images of each class).
        #
        # note: assumes max support label is 10000 - max_images_per_class
        query_idx_within_class = tf.cast(
            tf.equal(query_labels[tf.newaxis, :], query_labels[:, tf.newaxis]),
            tf.int32)
        query_idx_within_class = tf.linalg.diag_part(
            tf.cumsum(query_idx_within_class, axis=1))
        query_uid = query_labels * 10000 + query_idx_within_class
        support_idx_within_class = tf.cast(
            tf.equal(support_labels[tf.newaxis, :],
                     support_labels[:, tf.newaxis]), tf.int32)
        support_idx_within_class = tf.linalg.diag_part(
            tf.cumsum(support_idx_within_class, axis=1))
        support_uid = support_labels * 10000 + support_idx_within_class

        # compute which support-set images have matches in the query set, and
        # discard the rest to produce the new query set.
        support_keep = tf.reduce_any(tf.equal(support_uid[:, tf.newaxis],
                                              query_uid[tf.newaxis, :]),
                                     axis=1)
        query_images = tf.boolean_mask(support_images, support_keep)

        support_labels = tf.range(tf.shape(support_labels)[0],
                                  dtype=support_labels.dtype)
        query_labels = tf.boolean_mask(support_labels, support_keep)
        query_class_ids = tf.boolean_mask(support_class_ids, support_keep)

        # Finally, apply SimCLR augmentation to all images.
        # Note simclr only blurs one image.
        query_images = simclr_augment(query_images, blur=True)
        support_images = simclr_augment(support_images)

        return (support_images, support_labels, support_class_ids,
                query_images, query_labels, query_class_ids)
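
The unique-id trick above can be checked on a toy label vector: the diagonal of the row-wise cumulative sum of the label-equality matrix gives each example's 1-based index within its class.

import tensorflow as tf

labels = tf.constant([0, 1, 0, 1, 1])
idx_within_class = tf.linalg.diag_part(
    tf.cumsum(
        tf.cast(tf.equal(labels[tf.newaxis, :], labels[:, tf.newaxis]), tf.int32),
        axis=1))
uid = labels * 10000 + idx_within_class
# idx_within_class -> [1, 1, 2, 2, 3]; uid -> [1, 10001, 2, 10002, 10003]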
def _remove_second_return_lidar_points(mesh_inputs, view_indices_2d_inputs):
    """removes the points that are not lidar first-return ."""
    if standard_fields.InputDataFields.point_spin_coordinates not in mesh_inputs:
        raise ValueError('spin_coordinates not in mesh_inputs.')
    first_return_mask = tf.equal(
        tf.cast(mesh_inputs[
            standard_fields.InputDataFields.point_spin_coordinates][:, 2],
                dtype=tf.int32), 0)
    for key in sorted(mesh_inputs):
        mesh_inputs[key] = tf.boolean_mask(mesh_inputs[key], first_return_mask)
    for key in sorted(view_indices_2d_inputs):
        view_indices_2d_inputs[key] = tf.transpose(
            tf.boolean_mask(
                tf.transpose(view_indices_2d_inputs[key], [1, 0, 2]),
                first_return_mask), [1, 0, 2])
Example #7
def _filter_valid_objects(inputs):
  """Removes the objects that do not contain 3d info.

  Args:
    inputs: A dictionary containing input tensors.
  """
  if standard_fields.InputDataFields.objects_class not in inputs:
    return

  valid_objects_mask = tf.reshape(
      tf.greater(inputs[standard_fields.InputDataFields.objects_class], 0),
      [-1])
  if standard_fields.InputDataFields.objects_has_3d_info in inputs:
    objects_with_3d_info = tf.reshape(
        tf.cast(
            inputs[standard_fields.InputDataFields.objects_has_3d_info],
            dtype=tf.bool), [-1])
    valid_objects_mask = tf.logical_and(objects_with_3d_info,
                                        valid_objects_mask)
  if standard_fields.InputDataFields.objects_difficulty in inputs:
    valid_objects_mask = tf.logical_and(
        valid_objects_mask,
        tf.greater(
            tf.reshape(
                inputs[standard_fields.InputDataFields.objects_difficulty],
                [-1]), 0))
  for key in _OBJECT_KEYS:
    if key in inputs:
      inputs[key] = tf.boolean_mask(inputs[key], valid_objects_mask)
Example #8
def experience_to_transitions(experience):
    boundary_mask = tf.logical_not(experience.is_boundary()[:, 0])
    experience = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, boundary_mask), experience)
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, True))
    actions = policy_steps.action
    return time_steps, actions, next_time_steps
Example #9
def embedding_regularization_loss(inputs,
                                  outputs,
                                  lambda_coef=0.0001,
                                  regularization_type='unit_length',
                                  is_intermediate=False):
  """Classification loss with an iou threshold.

  Args:
    inputs: A dictionary that contains
      num_valid_voxels - A tf.int32 tensor of size [batch_size].
      instance_ids - A tf.int32 tensor of size [batch_size, n].
    outputs: A dictionary that contains
      embeddings - A tf.float32 tensor of size [batch_size, n, f].
    lambda_coef: Regularization loss coefficient.
    regularization_type: Regularization loss type. Supported values are 'msq'
      and 'unit_length'. 'msq' stands for 'mean square' which penalizes the
      embedding vectors if they have a length far from zero. 'unit_length'
      penalizes the embedding vectors if they have a length far from one.
    is_intermediate: True if applied to intermediate predictions;
      otherwise, False.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  instance_ids_key = standard_fields.InputDataFields.object_instance_id_voxels
  num_voxels_key = standard_fields.InputDataFields.num_valid_voxels
  if is_intermediate:
    embedding_key = (
        standard_fields.DetectionResultFields
        .intermediate_instance_embedding_voxels)
  else:
    embedding_key = (
        standard_fields.DetectionResultFields.instance_embedding_voxels)
  if instance_ids_key not in inputs:
    raise ValueError('instance_ids is missing in inputs.')
  if embedding_key not in outputs:
    raise ValueError('embedding is missing in outputs.')
  if num_voxels_key not in inputs:
    raise ValueError('num_voxels is missing in inputs.')
  batch_size = inputs[num_voxels_key].get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('batch_size is not defined at graph construction time.')
  num_valid_voxels = inputs[num_voxels_key]
  num_voxels = tf.shape(inputs[instance_ids_key])[1]
  valid_mask = tf.less(
      tf.tile(tf.expand_dims(tf.range(num_voxels), axis=0), [batch_size, 1]),
      tf.expand_dims(num_valid_voxels, axis=1))
  valid_mask = tf.reshape(valid_mask, [-1])
  embedding_dims = outputs[embedding_key].get_shape().as_list()[-1]
  if embedding_dims is None:
    raise ValueError(
        'Embedding dimension is unknown at graph construction time.')
  embedding = tf.reshape(outputs[embedding_key], [-1, embedding_dims])
  embedding = tf.boolean_mask(embedding, valid_mask)
  return metric_learning_losses.regularization_loss(
      embedding=embedding,
      lambda_coef=lambda_coef,
      regularization_type=regularization_type)
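
The valid_mask built above keeps only the first num_valid_voxels entries of each padded row (the same result as tf.sequence_mask); a small sketch:

import tensorflow as tf

num_valid_voxels = tf.constant([2, 3])
num_voxels = 4
batch_size = 2
valid_mask = tf.less(
    tf.tile(tf.expand_dims(tf.range(num_voxels), axis=0), [batch_size, 1]),
    tf.expand_dims(num_valid_voxels, axis=1))
# valid_mask -> [[True, True, False, False],
#                [True, True, True, False]]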
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
    """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for the
      user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
    scores, score_no_click = _get_unnormalized_scores(next_states)

    # Obtain all possible slates given current docs in the candidate set.
    slate_size = next_actions.get_shape().as_list()[1]
    num_candidates = next_q_values.get_shape().as_list()[1]
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))
    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    # [num_of_slates, slate_size]
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    # [batch_size, num_of_slates, slate_size]
    next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
    # [batch_size, num_of_slates, slate_size]
    scores_slate = tf.gather(scores, slates, axis=1)
    # [batch_size, num_of_slates]
    batch_size = next_states.get_shape().as_list()[0]
    score_no_click_slate = tf.reshape(
        tf.tile(score_no_click,
                tf.shape(input=slates)[:1]), [batch_size, -1])

    # [batch_size, num_of_slates]
    next_q_target_slate = tf.reduce_sum(
        input_tensor=next_q_values_slate * scores_slate,
        axis=2) / (tf.reduce_sum(input_tensor=scores_slate, axis=2) +
                   score_no_click_slate)

    next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

    return reward + gamma * next_q_target_max * (
        1. - tf.cast(terminals, tf.float32))
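
The slate enumeration above can be sketched on its own: meshgrid builds every ordered slate of document indices, and slates that repeat a document are filtered out with tf.unique.

import tensorflow as tf

slate_size, num_candidates = 2, 3
mesh_args = [list(range(num_candidates))] * slate_size
slates = tf.reshape(tf.stack(tf.meshgrid(*mesh_args), axis=-1), [-1, slate_size])
unique_mask = tf.map_fn(
    lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
    slates,
    dtype=tf.bool)
slates = tf.boolean_mask(tensor=slates, mask=unique_mask)
# slates -> the 6 ordered pairs of distinct document indices out of 3 candidates.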
 def loss_fn():
     """Loss function."""
     num_classes = logits.get_shape().as_list()[-1]
     if num_classes is None:
         raise ValueError('Number of classes is unknown.')
     masked_logits = tf.boolean_mask(logits, background_mask)
     masked_weights = tf.pow(
         1.0 - tf.reshape(tf.nn.softmax(masked_logits)[:, 0], [-1, 1]),
         gamma)
     num_points = tf.shape(masked_logits)[0]
     masked_weights = masked_weights * tf.cast(
         num_points, dtype=tf.float32) / tf.reduce_sum(masked_weights)
     masked_labels_one_hot = tf.one_hot(indices=tf.boolean_mask(
         labels, background_mask),
                                        depth=num_classes)
     loss = classification_loss_fn(logits=masked_logits,
                                   labels=masked_labels_one_hot,
                                   weights=masked_weights)
     return loss
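
The weighting above is a focal-style down-weighting of points the model already classifies as background, renormalized so the weights sum to the number of points; a short sketch with an assumed gamma:

import tensorflow as tf

gamma = 2.0
masked_logits = tf.constant([[4.0, 0.0, 0.0],
                             [0.0, 3.0, 0.0]])  # column 0 is the background class
p_background = tf.reshape(tf.nn.softmax(masked_logits)[:, 0], [-1, 1])
masked_weights = tf.pow(1.0 - p_background, gamma)
num_points = tf.shape(masked_logits)[0]
masked_weights = masked_weights * tf.cast(
    num_points, dtype=tf.float32) / tf.reduce_sum(masked_weights)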
 def _body_fn(i, indices_range, indices):
   """Computes the indices of the i-th point feature in each segment."""
   indices_i = tf.math.unsorted_segment_max(
       data=indices_range, segment_ids=segment_ids, num_segments=num_segments)
   indices_i_positive_mask = tf.greater(indices_i, 0)
   indices_i_positive = tf.boolean_mask(indices_i, indices_i_positive_mask)
   boolean_mask = tf.scatter_nd(
       indices=tf.cast(
           tf.expand_dims(indices_i_positive - 1, axis=1), dtype=tf.int64),
       updates=tf.ones_like(indices_i_positive, dtype=tf.int32),
       shape=(n,))
   indices_range *= (1 - boolean_mask)
   indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
   indices_i = tf.pad(
       tf.expand_dims(indices_i, axis=1),
       paddings=[[0, 0], [i, num_samples_per_voxel - i - 1]])
   indices += indices_i
   i = i + 1
   return i, indices_range, indices
def select_slate_optimal(slate_size, s_no_click, s, q):
    """Selects the slate using exhaustive search.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """

    num_candidates = s.shape.as_list()[0]

    # Obtain all possible slates given current docs in the candidate set.
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))

    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    slate_q_values = tf.gather(s * q, slates)
    slate_scores = tf.gather(s, slates)
    slate_normalizer = tf.reduce_sum(input_tensor=slate_scores,
                                     axis=1) + s_no_click

    slate_q_values = slate_q_values / tf.expand_dims(slate_normalizer, 1)
    slate_sum_q_values = tf.reduce_sum(input_tensor=slate_q_values, axis=1)
    max_q_slate_index = tf.argmax(input=slate_sum_q_values)
    return tf.gather(slates, max_q_slate_index, axis=0)
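
A hypothetical call to select_slate_optimal above, with three made-up candidate documents and a slate of size 2:

import tensorflow as tf

s = tf.constant([1.0, 2.0, 0.5])   # click scores per document
q = tf.constant([0.3, 0.1, 0.8])   # predicted q values per document
s_no_click = tf.constant(0.5)      # score of the no-click option
best_slate = select_slate_optimal(slate_size=2, s_no_click=s_no_click, s=s, q=q)
# best_slate is a [2] tensor of document indices maximizing the slate's expected q value.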
def _box_classification_loss_unbatched(inputs_1, outputs_1, is_intermediate,
                                       is_balanced, mine_hard_negatives,
                                       hard_negative_score_threshold):
    """Loss function for input and outputs of batch size 1."""
    valid_mask = _get_voxels_valid_mask(inputs_1=inputs_1)
    if is_intermediate:
        logits = outputs_1[standard_fields.DetectionResultFields.
                           intermediate_object_semantic_voxels]
    else:
        logits = outputs_1[
            standard_fields.DetectionResultFields.object_semantic_voxels]
    num_classes = logits.get_shape().as_list()[-1]
    if num_classes is None:
        raise ValueError('Number of classes is unknown.')
    logits = tf.boolean_mask(tf.reshape(logits, [-1, num_classes]), valid_mask)
    labels = tf.boolean_mask(
        tf.reshape(
            inputs_1[standard_fields.InputDataFields.object_class_voxels],
            [-1, 1]), valid_mask)
    if mine_hard_negatives or is_balanced:
        instances = tf.boolean_mask(
            tf.reshape(
                inputs_1[
                    standard_fields.InputDataFields.object_instance_id_voxels],
                [-1]), valid_mask)
    params = {}
    if mine_hard_negatives:
        negative_scores = tf.reshape(tf.nn.softmax(logits)[:, 0], [-1])
        hard_negative_mask = tf.logical_and(
            tf.less(negative_scores, hard_negative_score_threshold),
            tf.equal(tf.reshape(labels, [-1]), 0))
        hard_negative_labels = tf.boolean_mask(labels, hard_negative_mask)
        hard_negative_logits = tf.boolean_mask(logits, hard_negative_mask)
        hard_negative_instances = tf.boolean_mask(
            tf.ones_like(instances) * (tf.reduce_max(instances) + 1),
            hard_negative_mask)
        logits = tf.concat([logits, hard_negative_logits], axis=0)
        instances = tf.concat([instances, hard_negative_instances], axis=0)
        labels = tf.concat([labels, hard_negative_labels], axis=0)
    if is_balanced:
        weights = loss_utils.get_balanced_loss_weights_multiclass(
            labels=tf.expand_dims(instances, axis=1))
        params['weights'] = weights
    return classification_loss_fn(logits=logits, labels=labels, **params)
Example #15
def prepare_kitti_dataset(inputs, valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  prepared_inputs[standard_fields.InputDataFields.point_positions] = inputs[
      standard_fields.InputDataFields.point_positions]
  prepared_inputs[standard_fields.InputDataFields.point_intensities] = inputs[
      standard_fields.InputDataFields.point_intensities]
  prepared_inputs[standard_fields.InputDataFields
                  .camera_intrinsics] = inputs['cameras/cam02/intrinsics/K']
  prepared_inputs[standard_fields.InputDataFields.
                  camera_rotation_matrix] = inputs['cameras/cam02/extrinsics/R']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_translation] = inputs['cameras/cam02/extrinsics/t']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_image] = inputs['cameras/cam02/image']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_raw_image] = inputs['cameras/cam02/image']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_original_image] = inputs['cameras/cam02/image']
  if 'scene_name' in inputs and 'frame_name' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.camera_image_name] = tf.strings.join(
            [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_length] = tf.reshape(
            inputs['objects/shape/dimension'][:, 0], [-1, 1])
    prepared_inputs[standard_fields.InputDataFields.objects_width] = tf.reshape(
        inputs['objects/shape/dimension'][:, 1], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_height] = tf.reshape(
            inputs['objects/shape/dimension'][:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    prepared_inputs[standard_fields.InputDataFields.objects_class] = tf.reshape(
        inputs['objects/category/label'], [-1, 1])
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[standard_fields.InputDataFields.objects_class],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[standard_fields.InputDataFields.objects_class],
              object_class))
    valid_objects_mask = tf.reshape(valid_objects_mask, [-1])
    for key in standard_fields.get_input_object_fields():
      if key in prepared_inputs:
        prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                               valid_objects_mask)

  return prepared_inputs
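
The valid-class filtering above reduces to OR-ing per-class equality masks; a sketch with made-up labels:

import tensorflow as tf

objects_class = tf.constant([[1], [3], [2], [3]])
valid_object_classes = [1, 2]
valid_objects_mask = tf.cast(
    tf.zeros_like(objects_class, dtype=tf.int32), dtype=tf.bool)
for object_class in valid_object_classes:
  valid_objects_mask = tf.logical_or(valid_objects_mask,
                                     tf.equal(objects_class, object_class))
valid_objects_mask = tf.reshape(valid_objects_mask, [-1])
# valid_objects_mask -> [True, False, True, False]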
Example #16
def mask_tensor(x, s):
    not_x = tf.boolean_mask(x, tf.logical_not(s))
    x = tf.boolean_mask(x, s)
    return x, not_x
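
A hypothetical usage of mask_tensor above, splitting a tensor into selected and complementary elements:

import tensorflow as tf

x = tf.constant([10, 20, 30, 40])
s = tf.constant([True, False, True, False])
selected, rest = mask_tensor(x, s)
# selected -> [10, 30]; rest -> [20, 40]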
Example #17
def prepare_waymo_open_dataset(inputs,
                               valid_object_classes=None,
                               max_object_distance_from_source=74.88):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: List of valid object classes. If None, it is ignored.
    max_object_distance_from_source: Maximum distance of objects from source. It
      will be ignored if None.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if standard_fields.InputDataFields.point_positions in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_positions] = inputs[
        standard_fields.InputDataFields.point_positions]
  if standard_fields.InputDataFields.point_intensities in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_intensities] = inputs[
        standard_fields.InputDataFields.point_intensities]
  if standard_fields.InputDataFields.point_elongations in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_elongations] = inputs[
        standard_fields.InputDataFields.point_elongations]
  if standard_fields.InputDataFields.point_normals in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_normals] = inputs[
        standard_fields.InputDataFields.point_normals]
  if 'cameras/front/intrinsics/K' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_intrinsics] = inputs['cameras/front/intrinsics/K']
  if 'cameras/front/extrinsics/R' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields
        .camera_rotation_matrix] = inputs['cameras/front/extrinsics/R']
  if 'cameras/front/extrinsics/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_translation] = inputs['cameras/front/extrinsics/t']
  if 'cameras/front/image' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image] = inputs['cameras/front/image']
    prepared_inputs[standard_fields.InputDataFields
                    .camera_raw_image] = inputs['cameras/front/image']
    prepared_inputs[standard_fields.InputDataFields
                    .camera_original_image] = inputs['cameras/front/image']
  if 'scene_name' in inputs and 'frame_name' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.camera_image_name] = tf.strings.join(
            [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_length] = tf.reshape(
            inputs['objects/shape/dimension'][:, 0], [-1, 1])
    prepared_inputs[standard_fields.InputDataFields.objects_width] = tf.reshape(
        inputs['objects/shape/dimension'][:, 1], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_height] = tf.reshape(
            inputs['objects/shape/dimension'][:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    prepared_inputs[standard_fields.InputDataFields.objects_class] = tf.reshape(
        inputs['objects/category/label'], [-1, 1])
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[standard_fields.InputDataFields.objects_class],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[standard_fields.InputDataFields.objects_class],
              object_class))
    valid_objects_mask = tf.reshape(valid_objects_mask, [-1])
    for key in standard_fields.get_input_object_fields():
      if key in prepared_inputs:
        prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                               valid_objects_mask)

  if max_object_distance_from_source is not None:
    if standard_fields.InputDataFields.objects_center in prepared_inputs:
      object_distances = tf.norm(
          prepared_inputs[standard_fields.InputDataFields.objects_center][:,
                                                                          0:2],
          axis=1)
      valid_mask = tf.less(object_distances, max_object_distance_from_source)
      for key in standard_fields.get_input_object_fields():
        if key in prepared_inputs:
          prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                                 valid_mask)

  return prepared_inputs
Example #18
def prepare_scannet_frame_dataset(inputs,
                                  min_pixel_depth=0.3,
                                  max_pixel_depth=6.0,
                                  valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    min_pixel_depth: Pixels with depth values less than this are pruned.
    max_pixel_depth: Pixels with depth values greater than this are pruned.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if 'cameras/rgbd_camera/intrinsics/K' not in inputs:
    raise ValueError('Intrinsic matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/R' not in inputs:
    raise ValueError('Extrinsic rotation matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/t' not in inputs:
    raise ValueError('Extrinsics translation is missing.')
  if 'cameras/rgbd_camera/depth_image' not in inputs:
    raise ValueError('Depth image is missing.')
  if 'cameras/rgbd_camera/color_image' not in inputs:
    raise ValueError('Color image is missing.')
  if 'frame_name' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image_name] = inputs['frame_name']
  camera_intrinsics = inputs['cameras/rgbd_camera/intrinsics/K']
  depth_image = inputs['cameras/rgbd_camera/depth_image']
  image_height = tf.shape(depth_image)[0]
  image_width = tf.shape(depth_image)[1]
  x, y = tf.meshgrid(
      tf.range(image_width), tf.range(image_height), indexing='xy')
  x = tf.reshape(tf.cast(x, dtype=tf.float32) + 0.5, [-1, 1])
  y = tf.reshape(tf.cast(y, dtype=tf.float32) + 0.5, [-1, 1])
  point_positions = projections.image_frame_to_camera_frame(
      image_frame=tf.concat([x, y], axis=1),
      camera_intrinsics=camera_intrinsics)
  rotate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/R']
  translate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/t']
  point_positions = projections.to_world_frame(
      camera_frame_points=point_positions,
      rotate_world_to_camera=rotate_world_to_camera,
      translate_world_to_camera=translate_world_to_camera)
  prepared_inputs[standard_fields.InputDataFields
                  .point_positions] = point_positions * tf.reshape(
                      depth_image, [-1, 1])
  depth_values = tf.reshape(depth_image, [-1])
  valid_depth_mask = tf.logical_and(
      tf.greater_equal(depth_values, min_pixel_depth),
      tf.less_equal(depth_values, max_pixel_depth))
  prepared_inputs[standard_fields.InputDataFields.point_colors] = tf.reshape(
      tf.cast(inputs['cameras/rgbd_camera/color_image'], dtype=tf.float32),
      [-1, 3])
  prepared_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 / 255.0)
  prepared_inputs[standard_fields.InputDataFields.point_colors] -= 1.0
  prepared_inputs[
      standard_fields.InputDataFields.point_positions] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_positions],
          valid_depth_mask)
  prepared_inputs[
      standard_fields.InputDataFields.point_colors] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_colors],
          valid_depth_mask)
  if 'cameras/rgbd_camera/semantic_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/semantic_image'], [-1, 1]),
            dtype=tf.int32)
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.boolean_mask(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            valid_depth_mask)
  if 'cameras/rgbd_camera/instance_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_instance_id_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/instance_image'], [-1]),
            dtype=tf.int32)
    prepared_inputs[standard_fields.InputDataFields
                    .object_instance_id_points] = tf.boolean_mask(
                        prepared_inputs[standard_fields.InputDataFields
                                        .object_instance_id_points],
                        valid_depth_mask)

  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[
                  standard_fields.InputDataFields.object_class_points],
              object_class))
    valid_objects_mask = tf.cast(
        valid_objects_mask,
        dtype=prepared_inputs[
            standard_fields.InputDataFields.object_class_points].dtype)
    prepared_inputs[standard_fields.InputDataFields
                    .object_class_points] *= valid_objects_mask
  return prepared_inputs
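
Two details of the preparation above in isolation: RGB values are rescaled from [0, 255] to [-1, 1], and only pixels whose depth lies inside [min_pixel_depth, max_pixel_depth] are kept.

import tensorflow as tf

colors = tf.constant([[0.0, 127.5, 255.0]])
colors = colors * (2.0 / 255.0) - 1.0        # -> [[-1., 0., 1.]]
depth_values = tf.constant([0.1, 1.5, 7.0])
valid_depth_mask = tf.logical_and(tf.greater_equal(depth_values, 0.3),
                                  tf.less_equal(depth_values, 6.0))
# valid_depth_mask -> [False, True, False] with the default 0.3 / 6.0 thresholds.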
def classification_loss_using_mask_iou_func(embeddings,
                                            logits,
                                            instance_ids,
                                            class_labels,
                                            num_samples,
                                            valid_mask=None,
                                            max_instance_id=None,
                                            similarity_strategy='dotproduct',
                                            is_balanced=True):
    """Classification loss using mask iou.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    logits: A tf.float32 tensor of size [batch_size, n, num_classes]. It is
      assumed that background is class 0.
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    class_labels: A tf.int32 tensor of size [batch_size, n]. It is assumed
      that the background voxels are assigned to class 0.
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the value
      is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for foreground vs. background voxels.

  Returns:
    A tf.float32 scalar loss tensor.
  """
    batch_size = embeddings.get_shape().as_list()[0]
    if batch_size is None:
        raise ValueError('Unknown batch size at graph construction time.')
    if max_instance_id is None:
        max_instance_id = tf.reduce_max(instance_ids)
    class_labels = tf.reshape(class_labels, [batch_size, -1, 1])
    sampled_embeddings, sampled_instance_ids, sampled_indices = (
        sampling_utils.balanced_sample(features=embeddings,
                                       instance_ids=instance_ids,
                                       num_samples=num_samples,
                                       valid_mask=valid_mask,
                                       max_instance_id=max_instance_id))
    losses = []
    for i in range(batch_size):
        embeddings_i = embeddings[i, :, :]
        instance_ids_i = instance_ids[i, :]
        class_labels_i = class_labels[i, :, :]
        logits_i = logits[i, :]
        sampled_embeddings_i = sampled_embeddings[i, :, :]
        sampled_instance_ids_i = sampled_instance_ids[i, :]
        sampled_indices_i = sampled_indices[i, :]
        sampled_class_labels_i = tf.gather(class_labels_i, sampled_indices_i)
        sampled_logits_i = tf.gather(logits_i, sampled_indices_i)
        if valid_mask is not None:
            valid_mask_i = valid_mask[i]
            embeddings_i = tf.boolean_mask(embeddings_i, valid_mask_i)
            instance_ids_i = tf.boolean_mask(instance_ids_i, valid_mask_i)
        loss_i = classification_loss_using_mask_iou_func_unbatched(
            embeddings=embeddings_i,
            instance_ids=instance_ids_i,
            sampled_embeddings=sampled_embeddings_i,
            sampled_instance_ids=sampled_instance_ids_i,
            sampled_class_labels=sampled_class_labels_i,
            sampled_logits=sampled_logits_i,
            similarity_strategy=similarity_strategy,
            is_balanced=is_balanced)
        losses.append(loss_i)
    return tf.math.reduce_mean(tf.stack(losses))
Example #20
    def update_state(self, inputs, outputs):
        """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
        detections_score = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_score], [-1])
        detections_class = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_class], [-1])
        detections_length = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_length],
            [-1])
        detections_height = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_height],
            [-1])
        detections_width = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_width], [-1])
        detections_center = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_center],
            [-1, 3])
        detections_rotation_matrix = tf.reshape(
            outputs[
                standard_fields.DetectionResultFields.objects_rotation_matrix],
            [-1, 3, 3])
        gt_class = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_class], [-1])
        gt_length = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_length], [-1])
        gt_height = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_height], [-1])
        gt_width = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_width], [-1])
        gt_center = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_center], [-1, 3])
        gt_rotation_matrix = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_rotation_matrix],
            [-1, 3, 3])
        for c in self.class_range:
            gt_mask_c = tf.equal(gt_class, c)
            num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
            gt_length_c = tf.boolean_mask(gt_length, gt_mask_c)
            gt_height_c = tf.boolean_mask(gt_height, gt_mask_c)
            gt_width_c = tf.boolean_mask(gt_width, gt_mask_c)
            gt_center_c = tf.boolean_mask(gt_center, gt_mask_c)
            gt_rotation_matrix_c = tf.boolean_mask(gt_rotation_matrix,
                                                   gt_mask_c)
            detections_mask_c = tf.equal(detections_class, c)
            num_detections_c = tf.math.reduce_sum(
                tf.cast(detections_mask_c, dtype=tf.int32))
            if num_detections_c == 0:
                continue
            det_length_c = tf.boolean_mask(detections_length,
                                           detections_mask_c)
            det_height_c = tf.boolean_mask(detections_height,
                                           detections_mask_c)
            det_width_c = tf.boolean_mask(detections_width, detections_mask_c)
            det_center_c = tf.boolean_mask(detections_center,
                                           detections_mask_c)
            det_rotation_matrix_c = tf.boolean_mask(detections_rotation_matrix,
                                                    detections_mask_c)
            det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
            det_scores_c, sorted_indices = tf.math.top_k(det_scores_c,
                                                         k=num_detections_c)
            det_length_c = tf.gather(det_length_c, sorted_indices)
            det_height_c = tf.gather(det_height_c, sorted_indices)
            det_width_c = tf.gather(det_width_c, sorted_indices)
            det_center_c = tf.gather(det_center_c, sorted_indices)
            det_rotation_matrix_c = tf.gather(det_rotation_matrix_c,
                                              sorted_indices)
            tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
            if num_gt_c > 0:
                ious_c = box_ops.iou3d(
                    boxes1_length=gt_length_c,
                    boxes1_height=gt_height_c,
                    boxes1_width=gt_width_c,
                    boxes1_center=gt_center_c,
                    boxes1_rotation_matrix=gt_rotation_matrix_c,
                    boxes2_length=det_length_c,
                    boxes2_height=det_height_c,
                    boxes2_width=det_width_c,
                    boxes2_center=det_center_c,
                    boxes2_rotation_matrix=det_rotation_matrix_c)
                max_overlap_gt_ids = tf.cast(tf.math.argmax(ious_c, axis=0),
                                             dtype=tf.int32)
                is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
                for i in tf.range(num_detections_c):
                    gt_id = max_overlap_gt_ids[i]
                    if (ious_c[gt_id, i] > self.iou_threshold
                            and is_gt_box_detected[gt_id] == 0):
                        tp_c = tf.maximum(
                            tf.one_hot(i, num_detections_c, dtype=tf.int32),
                            tp_c)
                        is_gt_box_detected = tf.maximum(
                            tf.one_hot(gt_id, num_gt_c, dtype=tf.int32),
                            is_gt_box_detected)
            self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
            self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
            self.num_gt[c] += num_gt_c
        return tf.no_op()
Example #21
def _non_nan_mean(tensor_list):
  """Calculates the mean of a list of tensors while ignoring nans."""
  tensor = tf.stack(tensor_list)
  not_nan = tf.logical_not(tf.math.is_nan(tensor))
  return tf.reduce_mean(tf.boolean_mask(tensor, not_nan))
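
A hypothetical usage of _non_nan_mean above; the NaN entry is ignored:

import tensorflow as tf

values = [tf.constant(1.0), tf.constant(float('nan')), tf.constant(3.0)]
mean = _non_nan_mean(values)
# mean -> 2.0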
Example #22
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except:
                pass

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed:
        tf.compat.v1.set_random_seed(seed)

    summaries_flush_secs = 10
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

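    # Positive (uniformly sampled) transitions come from the full replay
    # buffer; gather_all is used below to locate the failure transitions.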
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        num_steps=2,
        sample_batch_size=batch_size // 2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

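    # Slice out failure steps (nonzero task-agnostic reward), the steps just
    # before them, episode-initial steps, and the steps right after
    # initialization; these slices feed the safety-critic eval function.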
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

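    # Separate buffer holding only the failure and pre-failure transitions so
    # each training batch can be built half from uniform data, half from
    # failures.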
    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size // 2,
        num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

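    # Offline training loop: each batch concatenates uniformly sampled
    # transitions with failure transitions, optionally reweighted by
    # fail_weight.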
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
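            # Drop boundary transitions and use the next-step task-agnostic
            # reward as the safety label, optionally upweighting failures.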
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            for i, threshold in enumerate(thresholds):
                                tf.compat.v2.summary.scalar(
                                    name='{}_{}'.format(metric.name, threshold),
                                    data=metric.result()[i],
                                    step=global_step_val)
                        metric.reset_states()
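            # Periodically run the safety-critic evaluation and, if enabled, a
            # full policy evaluation on the eval environment.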
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
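            # Periodically roll out the eval policy in the monitoring
            # environment and log the rollout length and wall-clock time.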
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action = monitor_action.action
                    monitor_policy_state = monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)