Example #1

import tensorflow as tf

# Assumed project-local module (google-research/tf3d).
from tf3d import standard_fields

def pick_labeled_image(mesh_inputs, view_image_inputs, view_indices_2d_inputs,
                       view_name):
    """Pick the image with most number of labeled points projecting to it."""
    if view_name not in view_image_inputs:
        return
    if view_name not in view_indices_2d_inputs:
        return
    if standard_fields.InputDataFields.point_loss_weights not in mesh_inputs:
        raise ValueError(
            'The key `point_loss_weights` is missing from mesh_inputs.')
    height = tf.shape(view_image_inputs[view_name])[1]
    width = tf.shape(view_image_inputs[view_name])[2]
    # A point is valid for a view if its 2D projection lands inside the image.
    valid_points_y = tf.logical_and(
        tf.greater_equal(view_indices_2d_inputs[view_name][:, :, 0], 0),
        tf.less(view_indices_2d_inputs[view_name][:, :, 0], height))
    valid_points_x = tf.logical_and(
        tf.greater_equal(view_indices_2d_inputs[view_name][:, :, 1], 0),
        tf.less(view_indices_2d_inputs[view_name][:, :, 1], width))
    valid_points = tf.logical_and(valid_points_y, valid_points_x)
    # Per-view sum of loss weights over the points that project into the view.
    image_total_weights = tf.reduce_sum(
        tf.cast(valid_points, dtype=tf.float32) * tf.squeeze(
            mesh_inputs[standard_fields.InputDataFields.point_loss_weights],
            axis=1),
        axis=1)
    # If no valid point carries any weight, fall back to a plain point count.
    image_total_weights = tf.cond(
        tf.equal(tf.reduce_sum(image_total_weights), 0),
        lambda: tf.reduce_sum(tf.cast(valid_points, dtype=tf.float32), axis=1),
        lambda: image_total_weights)
    best_image = tf.math.argmax(image_total_weights)
    view_image_inputs[view_name] = view_image_inputs[view_name][
        best_image:best_image + 1, :, :, :]
    view_indices_2d_inputs[view_name] = view_indices_2d_inputs[view_name][
        best_image:best_image + 1, :, :]
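
A minimal call sketch, assuming the tf3d field names used above; every tensor
below is a made-up dummy (3 candidate views, 5 labeled points):

# Hypothetical dummy inputs.
view_images = {'rgb': tf.random.uniform([3, 64, 64, 3])}
view_indices = {
    'rgb': tf.random.uniform([3, 5, 2], minval=-8, maxval=72, dtype=tf.int32)
}
mesh = {standard_fields.InputDataFields.point_loss_weights: tf.ones([5, 1])}
pick_labeled_image(mesh, view_images, view_indices, view_name='rgb')
print(view_images['rgb'].shape)  # (1, 64, 64, 3): only the best view remains.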
Example #2

import tensorflow as tf

# Assumed: tf_agents utility module providing generate_tensor_summaries.
from tf_agents.utils import common

def safety_critic_loss(time_steps,
                       actions,
                       next_time_steps,
                       safety_rewards,
                       get_action,
                       global_step,
                       critic_network=None,
                       target_network=None,
                       target_safety=None,
                       safety_gamma=0.45,
                       loss_fn='bce',
                       metrics=None,
                       debug_summaries=False):
    """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    safety_rewards: Task-agnostic rewards for safety. 1 is unsafe, 0 is safe.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    safe_critic_loss: A scalar critic loss.
  """
    with tf.name_scope('safety_critic_loss'):
        next_actions = get_action(next_time_steps)
        target_input = (next_time_steps.observation, next_actions)
        target_q_values, _ = target_network(target_input,
                                            next_time_steps.step_type)
        target_q_values = tf.nn.sigmoid(target_q_values)
        # Unsafe transitions (reward 1) pin the target to 1; safe ones
        # bootstrap from the discounted target-network estimate.
        td_targets = tf.stop_gradient(safety_rewards + (1 - safety_rewards) *
                                      safety_gamma * next_time_steps.discount *
                                      target_q_values)

        if loss_fn == 'bce' or loss_fn == tf.keras.losses.binary_crossentropy:
            td_targets = tf.nn.sigmoid(td_targets)

        pred_input = (time_steps.observation, actions)
        pred_td_targets, _ = critic_network(pred_input,
                                            time_steps.step_type,
                                            training=True)
        pred_td_targets = tf.nn.sigmoid(pred_td_targets)

        # Loss fns: binary_crossentropy/squared_difference
        if loss_fn == 'mse':
            sc_loss = tf.math.squared_difference(td_targets, pred_td_targets)
        elif loss_fn == 'bce' or loss_fn is None:
            sc_loss = tf.keras.losses.binary_crossentropy(
                td_targets, pred_td_targets)
        else:  # Any other callable is applied directly.
            sc_loss = loss_fn(td_targets, pred_td_targets)

        if metrics:
            for metric in metrics:
                if isinstance(metric, tf.keras.metrics.AUC):
                    metric.update_state(safety_rewards, pred_td_targets)
                else:
                    rew_pred = tf.greater_equal(pred_td_targets, target_safety)
                    metric.update_state(safety_rewards, rew_pred)

        if debug_summaries:
            # pred_td_targets was already passed through a sigmoid above;
            # applying it again would distort the summaries.
            td_errors = td_targets - pred_td_targets
            common.generate_tensor_summaries('safety_td_errors', td_errors,
                                             global_step)
            common.generate_tensor_summaries('safety_td_targets', td_targets,
                                             global_step)
            common.generate_tensor_summaries('safety_pred_td_targets',
                                             pred_td_targets, global_step)

        return sc_loss
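
The heart of the loss is the bootstrapped TD target; a self-contained sketch of
just that computation, with made-up batch values (safety_gamma as in the
default above):

safety_gamma = 0.45
safety_rewards = tf.constant([0., 1., 0.])  # 1 = unsafe transition
discount = tf.constant([1., 1., 0.])        # 0 at episode boundaries
target_q = tf.nn.sigmoid(tf.constant([0.2, -1.0, 0.5]))  # logits -> probabilities

# Unsafe transitions pin the target to 1; safe ones bootstrap from the target net.
td_targets = tf.stop_gradient(
    safety_rewards + (1. - safety_rewards) * safety_gamma * discount * target_q)
# td_targets ~= [0.247, 1.0, 0.0]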
Example #3

import tensorflow as tf

# Assumed project-local modules (google-research/tf3d).
from tf3d import standard_fields
from tf3d.utils import projections

def prepare_scannet_frame_dataset(inputs,
                                  min_pixel_depth=0.3,
                                  max_pixel_depth=6.0,
                                  valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    min_pixel_depth: Pixels with depth values less than this are pruned.
    max_pixel_depth: Pixels with depth values more than this are pruned.
    valid_object_classes: List of valid object classes. If None, no class
      filtering is applied.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if 'cameras/rgbd_camera/intrinsics/K' not in inputs:
    raise ValueError('Intrinsic matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/R' not in inputs:
    raise ValueError('Extrinsic rotation matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/t' not in inputs:
    raise ValueError('Extrinsics translation is missing.')
  if 'cameras/rgbd_camera/depth_image' not in inputs:
    raise ValueError('Depth image is missing.')
  if 'cameras/rgbd_camera/color_image' not in inputs:
    raise ValueError('Color image is missing.')
  if 'frame_name' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image_name] = inputs['frame_name']
  camera_intrinsics = inputs['cameras/rgbd_camera/intrinsics/K']
  depth_image = inputs['cameras/rgbd_camera/depth_image']
  image_height = tf.shape(depth_image)[0]
  image_width = tf.shape(depth_image)[1]
  # Pixel-center coordinates (hence the +0.5 below) for every pixel.
  x, y = tf.meshgrid(
      tf.range(image_width), tf.range(image_height), indexing='xy')
  x = tf.reshape(tf.cast(x, dtype=tf.float32) + 0.5, [-1, 1])
  y = tf.reshape(tf.cast(y, dtype=tf.float32) + 0.5, [-1, 1])
  point_positions = projections.image_frame_to_camera_frame(
      image_frame=tf.concat([x, y], axis=1),
      camera_intrinsics=camera_intrinsics)
  rotate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/R']
  translate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/t']
  point_positions = projections.to_world_frame(
      camera_frame_points=point_positions,
      rotate_world_to_camera=rotate_world_to_camera,
      translate_world_to_camera=translate_world_to_camera)
  prepared_inputs[standard_fields.InputDataFields
                  .point_positions] = point_positions * tf.reshape(
                      depth_image, [-1, 1])
  # Keep only points whose depth lies in [min_pixel_depth, max_pixel_depth].
  depth_values = tf.reshape(depth_image, [-1])
  valid_depth_mask = tf.logical_and(
      tf.greater_equal(depth_values, min_pixel_depth),
      tf.less_equal(depth_values, max_pixel_depth))
  prepared_inputs[standard_fields.InputDataFields.point_colors] = tf.reshape(
      tf.cast(inputs['cameras/rgbd_camera/color_image'], dtype=tf.float32),
      [-1, 3])
  # Normalize colors from [0, 255] to [-1, 1].
  prepared_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 / 255.0)
  prepared_inputs[standard_fields.InputDataFields.point_colors] -= 1.0
  prepared_inputs[
      standard_fields.InputDataFields.point_positions] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_positions],
          valid_depth_mask)
  prepared_inputs[
      standard_fields.InputDataFields.point_colors] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_colors],
          valid_depth_mask)
  if 'cameras/rgbd_camera/semantic_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/semantic_image'], [-1, 1]),
            dtype=tf.int32)
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.boolean_mask(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            valid_depth_mask)
  if 'cameras/rgbd_camera/instance_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_instance_id_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/instance_image'], [-1]),
            dtype=tf.int32)
    prepared_inputs[standard_fields.InputDataFields
                    .object_instance_id_points] = tf.boolean_mask(
                        prepared_inputs[standard_fields.InputDataFields
                                        .object_instance_id_points],
                        valid_depth_mask)

  if valid_object_classes is not None:
    # Class filtering needs the semantic labels extracted above; without them
    # the lookup below would raise a bare KeyError.
    if (standard_fields.InputDataFields.object_class_points
        not in prepared_inputs):
      raise ValueError('valid_object_classes requires a semantic image.')
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[
                  standard_fields.InputDataFields.object_class_points],
              object_class))
    valid_objects_mask = tf.cast(
        valid_objects_mask,
        dtype=prepared_inputs[
            standard_fields.InputDataFields.object_class_points].dtype)
    prepared_inputs[standard_fields.InputDataFields
                    .object_class_points] *= valid_objects_mask
  return prepared_inputs
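
The core of this function is back-projecting the depth image through the
pinhole intrinsics; below is a self-contained sketch of that step with an
explicit camera model (hypothetical focal lengths and principal point, no tf3d
dependency):

def backproject_depth(depth, fx, fy, cx, cy):
  """Lifts an [H, W] depth image to an [H*W, 3] camera-frame point cloud."""
  height = tf.shape(depth)[0]
  width = tf.shape(depth)[1]
  x, y = tf.meshgrid(tf.range(width), tf.range(height), indexing='xy')
  x = tf.reshape(tf.cast(x, tf.float32) + 0.5, [-1])  # pixel centers
  y = tf.reshape(tf.cast(y, tf.float32) + 0.5, [-1])
  z = tf.reshape(depth, [-1])
  # Pinhole model: X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy.
  return tf.stack([(x - cx) * z / fx, (y - cy) * z / fy, z], axis=1)

points = backproject_depth(
    tf.random.uniform([4, 4], minval=0.3, maxval=6.0),
    fx=500., fy=500., cx=2., cy=2.)
print(points.shape)  # (16, 3)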