Example #1
0
def filter_before_first_step(time_steps, actions=None):
  """Overwrites the steps that come before the last FIRST step of a sequence.

  Every tensor is unstacked along axis 1 (presumably the time axis of a
  [batch, time, ...] layout - confirm against callers) and the resulting steps
  are scanned from the final index backwards. As soon as a step whose
  `step_type` equals `ts.StepType.FIRST` has been seen for a batch entry, each
  earlier step of that entry is replaced by the (already filtered) step that
  follows it, and the matching action is replaced by zeros. The step at the
  final index and its action are never modified.

  Args:
    time_steps: A nest of tensors that exposes a `step_type` attribute once
      packed per step (e.g. a TimeStep-like structure).
    actions: Optional tensor of actions, unstacked along axis 1 in lockstep
      with `time_steps`.

  Returns:
    The filtered time steps, restacked along axis 1, when `actions` is None;
    otherwise a `(filtered_time_steps, filtered_actions)` tuple.
  """
  # Decide the return shape from the argument itself instead of relying on a
  # loop variable that happens to survive the scan below.
  has_actions = actions is not None

  # Split each leaf along axis 1 and regroup the pieces into one nest per step.
  per_leaf_steps = [tf.unstack(leaf, axis=1)
                    for leaf in tf.nest.flatten(time_steps)]
  step_list = [tf.nest.pack_sequence_as(time_steps, leaves)
               for leaves in zip(*per_leaf_steps)]
  if has_actions:
    action_list = tf.unstack(actions, axis=1)
  else:
    action_list = [None] * len(step_list)
  assert len(step_list) == len(action_list)

  filtered_steps = []
  filtered_actions = []
  reset_mask = None  # True once a FIRST step was seen at a later index.
  later_step = None  # Filtered step at the next (later) index.
  for step, action in zip(reversed(step_list), reversed(action_list)):
    if reset_mask is None:
      # Final index of the sequence: nothing follows it, so keep it untouched.
      reset_mask = tf.equal(step.step_type, ts.StepType.FIRST)
    else:
      step = tf.nest.map_structure(
          lambda x, y: tf.where(reset_mask, x, y), later_step, step)
      if action is not None:
        action = tf.where(reset_mask, tf.zeros_like(action), action)
    filtered_steps.append(step)
    filtered_actions.append(action)
    reset_mask = tf.logical_or(
        reset_mask, tf.equal(step.step_type, ts.StepType.FIRST))
    later_step = step

  # Restore chronological order before stacking back along axis 1.
  filtered_steps.reverse()
  filtered_actions.reverse()

  flat_filtered_steps = [tf.nest.flatten(step) for step in filtered_steps]
  stacked_leaves = [tf.stack(leaves, axis=1)
                    for leaves in zip(*flat_filtered_steps)]
  stacked_time_steps = tf.nest.pack_sequence_as(filtered_steps[0],
                                                stacked_leaves)
  if not has_actions:
    return stacked_time_steps
  return stacked_time_steps, tf.stack(filtered_actions, axis=1)
Example #2
0
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):
    """Trains a new safety critic offline from a saved agent and replay buffer.

    The agent and its replay buffer are restored from `load_root_dir/train`
    and `load_root_dir/train/replay_buffer`. The global step is then reset to
    zero and a freshly initialised critic network ('SafetyCriticOffline') is
    trained on batches that concatenate samples from the full replay buffer
    with samples from a buffer holding only the failure-related transitions.
    Summaries and checkpoints are written under
    `load_root_dir/train/offline/<utc timestamp>`.

    Args:
        load_root_dir: Root directory of the run to load from. Evaluation
            summaries go to `<root>/eval` and monitor rollouts to
            `<root>/rollouts`.
        env_load_fn: Callable invoked as
            `env_load_fn(env_name, gym_env_wrappers=...)` to build a py
            environment. Required in practice despite the `None` default.
        gym_env_wrappers: Wrappers forwarded to `env_load_fn` and applied to
            the monitor environment.
        monitor: Whether to record rollouts every `monitor_interval` steps.
        env_name: Environment name passed to `env_load_fn` and `gym.make`.
        agent_class: Agent class, or a string key looked up in `ALGOS`.
        train_metrics_callback: Optional callable invoked as
            `callback(results, step)` after each evaluation.
        actor_fc_layers: Fully connected layer sizes of the actor network.
        critic_joint_fc_layers: Joint layer sizes of the loaded agent's critic
            (and of its safety critic for agents in `SAFETY_AGENTS`).
        safety_critic_joint_fc_layers: Joint layer sizes of the offline safety
            critic being trained.
        safety_critic_lr: Learning rate of the Adam optimizer.
        safety_critic_bias_init_val: Optional constant used to initialise the
            value bias of the offline safety critic.
        safety_critic_kernel_scale: Optional scale for a truncated-normal
            variance-scaling kernel initializer; a uniform initializer with
            scale 1/3 is used when None.
        n_envs: Number of parallel evaluation environments; falls back to
            `num_eval_episodes`.
        target_safety: Safety threshold for the evaluation policy and the
            thresholded metrics. For agents in `SAFETY_AGENTS` a falsy value
            falls back to the loaded agent's `_target_safety`.
        fail_weight: Optional weight in (0, 1) used to build per-sample
            weights from the 'task_agn_rew' signal; no weights when falsy.
        num_global_steps: Number of offline training steps to run.
        batch_size: Total batch size; half is sampled from each buffer.
        run_eval: Whether to evaluate in the environment every `eval_interval`
            steps.
        eval_metrics: Extra py metrics, each wrapped in a `TFPyMetric`.
        num_eval_episodes: Number of episodes per evaluation.
        eval_interval: Steps between safety-critic (and optional environment)
            evaluations.
        train_checkpoint_interval: Steps between safety-critic checkpoints.
        summary_interval: Steps between summary writes.
        monitor_interval: Steps between monitored rollouts.
        summaries_flush_secs: Flush period, in seconds, of the summary writers.
        debug_summaries: Forwarded to `train_step`.
        seed: Optional integer seed for the evaluation environments and for
            TensorFlow. `0` is a valid seed.
    """

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        # `is not None` so that a seed of 0 is honoured.
        if seed is not None:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            # Seeding stays best-effort, but the failure is reported instead
            # of being swallowed silently.
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception as e:
                logging.warning('Could not seed evaluation environments: %s',
                                e)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    # Offline training counts its own steps from zero.
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed is not None:
        tf.compat.v1.set_random_seed(seed)

    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    # Uses the caller-supplied flush period rather than a hard-coded one.
    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       num_steps=2,
                                       sample_batch_size=batch_size //
                                       2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    # Slices of the full buffer handed to the safety-critic evaluation
    # function: failure steps, initial steps and their temporal neighbours.
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    # Pad the failure mask up to the buffer capacity so it can select rows
    # when copying into the dedicated failure buffer.
    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3,
                                               sample_batch_size=batch_size //
                                               2,
                                               num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            # Drop two-step samples whose first step is an episode boundary.
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            # Thresholded metrics report one value per entry
                            # of `thresholds`.
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
Example #3
0
def prepare_scannet_scene_dataset(inputs, valid_object_classes=None):
  """Converts a loaded ScanNet scene example to the standard field names.

  Mesh-vertex entries that are present in `inputs` are stored under their
  `standard_fields.InputDataFields` keys. NaN normals are replaced by zeros.
  The first three color channels are cast to float32, scaled by 2/255 and
  shifted by -1.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: Optional list of class ids to keep. Per-point class
      labels that match none of them are multiplied by zero. Ignored if None.

  Returns:
    A dictionary of input tensors keyed by standard field names.
  """
  outputs = {}
  if 'mesh/vertices/positions' in inputs:
    outputs[standard_fields.InputDataFields.point_positions] = (
        inputs['mesh/vertices/positions'])
  if 'mesh/vertices/normals' in inputs:
    normals = inputs['mesh/vertices/normals']
    # NaN entries are swapped for zeros; everything else passes through.
    outputs[standard_fields.InputDataFields.point_normals] = tf.where(
        tf.math.is_nan(normals), tf.zeros_like(normals), normals)
  if 'mesh/vertices/colors' in inputs:
    colors = tf.cast(inputs['mesh/vertices/colors'][:, 0:3], dtype=tf.float32)
    colors = colors * (2.0 / 255.0)
    outputs[standard_fields.InputDataFields.point_colors] = colors - 1.0
  if 'scene_name' in inputs:
    outputs[standard_fields.InputDataFields.camera_image_name] = (
        inputs['scene_name'])
  if 'mesh/vertices/semantic_labels' in inputs:
    outputs[standard_fields.InputDataFields.object_class_points] = (
        inputs['mesh/vertices/semantic_labels'])
  if 'mesh/vertices/instance_labels' in inputs:
    outputs[standard_fields.InputDataFields.object_instance_id_points] = (
        tf.reshape(inputs['mesh/vertices/instance_labels'], [-1]))

  if valid_object_classes is None:
    return outputs

  class_key = standard_fields.InputDataFields.object_class_points
  labels = outputs[class_key]
  # Start from an all-False mask and switch on every point whose label matches
  # one of the requested classes.
  keep_mask = tf.zeros_like(labels, dtype=tf.bool)
  for class_id in valid_object_classes:
    keep_mask = tf.logical_or(keep_mask, tf.equal(labels, class_id))
  outputs[class_key] = labels * tf.cast(keep_mask, dtype=labels.dtype)
  return outputs
Example #4
0
def prepare_kitti_dataset(inputs, valid_object_classes=None):
  """Converts a loaded KITTI example to the standard field names.

  Point positions, point intensities and the `cam02` camera entries are
  required and copied unconditionally; the same image tensor is stored as the
  camera image, raw image and original image. Image name and object entries
  are copied only when present in `inputs`.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: Optional list of class ids. When given, every object
      field present in the output is masked down to the objects whose class
      is in the list. Ignored if None.

  Returns:
    A dictionary of input tensors keyed by standard field names.
  """
  fields = standard_fields.InputDataFields
  outputs = {
      fields.point_positions: inputs[fields.point_positions],
      fields.point_intensities: inputs[fields.point_intensities],
      fields.camera_intrinsics: inputs['cameras/cam02/intrinsics/K'],
      fields.camera_rotation_matrix: inputs['cameras/cam02/extrinsics/R'],
      fields.camera_translation: inputs['cameras/cam02/extrinsics/t'],
      fields.camera_image: inputs['cameras/cam02/image'],
      fields.camera_raw_image: inputs['cameras/cam02/image'],
      fields.camera_original_image: inputs['cameras/cam02/image'],
  }
  if 'scene_name' in inputs and 'frame_name' in inputs:
    outputs[fields.camera_image_name] = tf.strings.join(
        [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    outputs[fields.objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    outputs[fields.objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    dimensions = inputs['objects/shape/dimension']
    # Columns 0, 1, 2 become length, width and height, each as a column.
    outputs[fields.objects_length] = tf.reshape(dimensions[:, 0], [-1, 1])
    outputs[fields.objects_width] = tf.reshape(dimensions[:, 1], [-1, 1])
    outputs[fields.objects_height] = tf.reshape(dimensions[:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    outputs[fields.objects_class] = tf.reshape(
        inputs['objects/category/label'], [-1, 1])

  if valid_object_classes is None:
    return outputs

  object_classes = outputs[fields.objects_class]
  keep_mask = tf.zeros_like(object_classes, dtype=tf.bool)
  for class_id in valid_object_classes:
    keep_mask = tf.logical_or(keep_mask, tf.equal(object_classes, class_id))
  keep_mask = tf.reshape(keep_mask, [-1])
  # Apply the same per-object mask to every object field that was populated.
  for field_name in standard_fields.get_input_object_fields():
    if field_name in outputs:
      outputs[field_name] = tf.boolean_mask(outputs[field_name], keep_mask)
  return outputs
Example #5
0
def prepare_waymo_open_dataset(inputs,
                               valid_object_classes=None,
                               max_object_distance_from_source=74.88):
  """Converts a loaded Waymo Open Dataset example to the standard field names.

  Every entry is optional: point fields, front-camera fields, the image name
  and object fields are copied only when present in `inputs`. The front-camera
  image is stored as the camera image, raw image and original image.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: Optional list of class ids. When given, every object
      field present in the output is masked down to the objects whose class
      is in the list. Ignored if None.
    max_object_distance_from_source: Objects whose center has a norm over its
      first two coordinates that is not below this value are dropped. Ignored
      if None.

  Returns:
    A dictionary of input tensors keyed by standard field names.
  """
  fields = standard_fields.InputDataFields
  outputs = {}

  # Point fields already use the standard key on the input side.
  for field_name in (fields.point_positions, fields.point_intensities,
                     fields.point_elongations, fields.point_normals):
    if field_name in inputs:
      outputs[field_name] = inputs[field_name]

  # Front-camera calibration entries that map one-to-one.
  for input_key, field_name in (
      ('cameras/front/intrinsics/K', fields.camera_intrinsics),
      ('cameras/front/extrinsics/R', fields.camera_rotation_matrix),
      ('cameras/front/extrinsics/t', fields.camera_translation)):
    if input_key in inputs:
      outputs[field_name] = inputs[input_key]

  if 'cameras/front/image' in inputs:
    front_image = inputs['cameras/front/image']
    outputs[fields.camera_image] = front_image
    outputs[fields.camera_raw_image] = front_image
    outputs[fields.camera_original_image] = front_image
  if 'scene_name' in inputs and 'frame_name' in inputs:
    outputs[fields.camera_image_name] = tf.strings.join(
        [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    outputs[fields.objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    outputs[fields.objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    dimensions = inputs['objects/shape/dimension']
    # Columns 0, 1, 2 become length, width and height, each as a column.
    outputs[fields.objects_length] = tf.reshape(dimensions[:, 0], [-1, 1])
    outputs[fields.objects_width] = tf.reshape(dimensions[:, 1], [-1, 1])
    outputs[fields.objects_height] = tf.reshape(dimensions[:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    outputs[fields.objects_class] = tf.reshape(
        inputs['objects/category/label'], [-1, 1])

  if valid_object_classes is not None:
    object_classes = outputs[fields.objects_class]
    class_mask = tf.zeros_like(object_classes, dtype=tf.bool)
    for class_id in valid_object_classes:
      class_mask = tf.logical_or(class_mask,
                                 tf.equal(object_classes, class_id))
    class_mask = tf.reshape(class_mask, [-1])
    for field_name in standard_fields.get_input_object_fields():
      if field_name in outputs:
        outputs[field_name] = tf.boolean_mask(outputs[field_name], class_mask)

  if (max_object_distance_from_source is not None and
      fields.objects_center in outputs):
    # Distance uses only the first two coordinates of each object center.
    planar_distances = tf.norm(outputs[fields.objects_center][:, 0:2], axis=1)
    distance_mask = tf.less(planar_distances, max_object_distance_from_source)
    for field_name in standard_fields.get_input_object_fields():
      if field_name in outputs:
        outputs[field_name] = tf.boolean_mask(outputs[field_name],
                                              distance_mask)

  return outputs
Example #6
0
def prepare_scannet_frame_dataset(inputs,
                                  min_pixel_depth=0.3,
                                  max_pixel_depth=6.0,
                                  valid_object_classes=None):
  """Converts a loaded ScanNet RGB-D frame to the standard field names.

  One point is produced per pixel from the depth image and the camera
  calibration; points, colors and any per-pixel labels are then restricted to
  pixels whose depth lies in [min_pixel_depth, max_pixel_depth]. Colors are
  cast to float32, scaled by 2/255 and shifted by -1.

  Args:
    inputs: A dictionary of input tensors.
    min_pixel_depth: Pixels with depth values less than this are pruned.
    max_pixel_depth: Pixels with depth values more than this are pruned.
    valid_object_classes: Optional list of class ids to keep. Per-point class
      labels that match none of them are multiplied by zero. Ignored if None.

  Returns:
    A dictionary of input tensors keyed by standard field names.

  Raises:
    ValueError: If the intrinsics, extrinsic rotation, extrinsic translation,
      depth image or color image is missing from `inputs`.
  """
  outputs = {}
  required_entries = (
      ('cameras/rgbd_camera/intrinsics/K', 'Intrinsic matrix is missing.'),
      ('cameras/rgbd_camera/extrinsics/R',
       'Extrinsic rotation matrix is missing.'),
      ('cameras/rgbd_camera/extrinsics/t',
       'Extrinsics translation is missing.'),
      ('cameras/rgbd_camera/depth_image', 'Depth image is missing.'),
      ('cameras/rgbd_camera/color_image', 'Color image is missing.'),
  )
  for input_key, error_message in required_entries:
    if input_key not in inputs:
      raise ValueError(error_message)

  if 'frame_name' in inputs:
    outputs[standard_fields.InputDataFields.camera_image_name] = (
        inputs['frame_name'])

  depth_image = inputs['cameras/rgbd_camera/depth_image']
  image_shape = tf.shape(depth_image)
  num_rows = image_shape[0]
  num_cols = image_shape[1]

  # Pixel-center coordinates (index + 0.5), flattened to one column each.
  col_grid, row_grid = tf.meshgrid(
      tf.range(num_cols), tf.range(num_rows), indexing='xy')
  pixel_x = tf.reshape(tf.cast(col_grid, dtype=tf.float32) + 0.5, [-1, 1])
  pixel_y = tf.reshape(tf.cast(row_grid, dtype=tf.float32) + 0.5, [-1, 1])

  camera_points = projections.image_frame_to_camera_frame(
      image_frame=tf.concat([pixel_x, pixel_y], axis=1),
      camera_intrinsics=inputs['cameras/rgbd_camera/intrinsics/K'])
  world_points = projections.to_world_frame(
      camera_frame_points=camera_points,
      rotate_world_to_camera=inputs['cameras/rgbd_camera/extrinsics/R'],
      translate_world_to_camera=inputs['cameras/rgbd_camera/extrinsics/t'])

  depth_values = tf.reshape(depth_image, [-1])
  in_range_mask = tf.logical_and(
      tf.greater_equal(depth_values, min_pixel_depth),
      tf.less_equal(depth_values, max_pixel_depth))

  # NOTE(review): the depth scaling is applied after the world-frame
  # transform - confirm this ordering is intended.
  scaled_points = world_points * tf.reshape(depth_image, [-1, 1])
  outputs[standard_fields.InputDataFields.point_positions] = tf.boolean_mask(
      scaled_points, in_range_mask)

  colors = tf.reshape(
      tf.cast(inputs['cameras/rgbd_camera/color_image'], dtype=tf.float32),
      [-1, 3])
  colors = colors * (2.0 / 255.0)
  colors = colors - 1.0
  outputs[standard_fields.InputDataFields.point_colors] = tf.boolean_mask(
      colors, in_range_mask)

  if 'cameras/rgbd_camera/semantic_image' in inputs:
    semantic_labels = tf.cast(
        tf.reshape(inputs['cameras/rgbd_camera/semantic_image'], [-1, 1]),
        dtype=tf.int32)
    outputs[standard_fields.InputDataFields.object_class_points] = (
        tf.boolean_mask(semantic_labels, in_range_mask))
  if 'cameras/rgbd_camera/instance_image' in inputs:
    instance_ids = tf.cast(
        tf.reshape(inputs['cameras/rgbd_camera/instance_image'], [-1]),
        dtype=tf.int32)
    outputs[standard_fields.InputDataFields.object_instance_id_points] = (
        tf.boolean_mask(instance_ids, in_range_mask))

  if valid_object_classes is None:
    return outputs

  class_key = standard_fields.InputDataFields.object_class_points
  labels = outputs[class_key]
  # Start from an all-False mask and switch on every point whose label matches
  # one of the requested classes.
  keep_mask = tf.zeros_like(labels, dtype=tf.bool)
  for class_id in valid_object_classes:
    keep_mask = tf.logical_or(keep_mask, tf.equal(labels, class_id))
  outputs[class_key] = labels * tf.cast(keep_mask, dtype=labels.dtype)
  return outputs