def filter_before_first_step(time_steps, actions=None):
  """Masks out transitions that precede an episode start within each sequence.

  Scans each [batch, time] sequence backwards and, once a StepType.FIRST is
  found, overwrites every earlier time step with the episode-start values and
  zeroes the corresponding actions.
  """
  flat_time_steps = tf.nest.flatten(time_steps)
  flat_time_steps = [
      tf.unstack(time_step, axis=1) for time_step in flat_time_steps
  ]
  time_steps = [
      tf.nest.pack_sequence_as(time_steps, time_step)
      for time_step in zip(*flat_time_steps)
  ]
  if actions is None:
    actions = [None] * len(time_steps)
  else:
    actions = tf.unstack(actions, axis=1)
  assert len(time_steps) == len(actions)

  # Walk the sequence backwards so the reset mask accumulates once a FIRST
  # step has been seen.
  time_steps = list(reversed(time_steps))
  actions = list(reversed(actions))
  filtered_time_steps = []
  filtered_actions = []
  for t, (time_step, action) in enumerate(zip(time_steps, actions)):
    if t == 0:
      reset_mask = tf.equal(time_step.step_type, ts.StepType.FIRST)
    else:
      time_step = tf.nest.map_structure(
          lambda x, y: tf.where(reset_mask, x, y), last_time_step, time_step)
      action = tf.where(reset_mask, tf.zeros_like(action),
                        action) if action is not None else None
    filtered_time_steps.append(time_step)
    filtered_actions.append(action)
    reset_mask = tf.logical_or(
        reset_mask, tf.equal(time_step.step_type, ts.StepType.FIRST))
    last_time_step = time_step
  filtered_time_steps = list(reversed(filtered_time_steps))
  filtered_actions = list(reversed(filtered_actions))

  filtered_flat_time_steps = [
      tf.nest.flatten(time_step) for time_step in filtered_time_steps
  ]
  filtered_flat_time_steps = [
      tf.stack(time_step, axis=1)
      for time_step in zip(*filtered_flat_time_steps)
  ]
  filtered_time_steps = tf.nest.pack_sequence_as(filtered_time_steps[0],
                                                 filtered_flat_time_steps)
  # `action` is still None here only if no actions were passed in.
  if action is None:
    return filtered_time_steps
  else:
    actions = tf.stack(filtered_actions, axis=1)
    return filtered_time_steps, actions
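# Illustrative usage sketch (not part of the original module): builds a toy
# [batch=1, time=3] TimeStep batch whose episode boundary falls mid-sequence
# and filters out the stale pre-reset transition. The shapes, field values,
# and helper name are assumptions for demonstration only; `tf` and `ts`
# (tf_agents.trajectories.time_step) are the module's existing imports.
def _example_filter_before_first_step():
  """Hypothetical example of calling filter_before_first_step."""
  step_type = tf.constant(
      [[ts.StepType.MID, ts.StepType.FIRST, ts.StepType.MID]], dtype=tf.int32)
  reward = tf.constant([[0.5, 0.0, 1.0]], dtype=tf.float32)
  discount = tf.ones_like(reward)
  observation = tf.reshape(tf.range(6, dtype=tf.float32), [1, 3, 2])
  time_steps = ts.TimeStep(step_type, reward, discount, observation)
  actions = tf.zeros([1, 3, 1], dtype=tf.float32)
  # The transition at t=0 precedes the FIRST step at t=1, so it is replaced by
  # the episode-start values and its action is zeroed.
  filtered_time_steps, filtered_actions = filter_before_first_step(
      time_steps, actions)
  return filtered_time_steps, filtered_actions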
def train_eval(
    load_root_dir,
    env_load_fn=None,
    gym_env_wrappers=[],
    monitor=False,
    env_name=None,
    agent_class=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    safety_critic_joint_fc_layers=None,
    safety_critic_lr=3e-4,
    safety_critic_bias_init_val=None,
    safety_critic_kernel_scale=None,
    n_envs=None,
    target_safety=0.2,
    fail_weight=None,
    # Params for train
    num_global_steps=10000,
    batch_size=256,
    # Params for eval
    run_eval=False,
    eval_metrics=[],
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    debug_summaries=False,
    seed=None):

  if isinstance(agent_class, str):
    assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
        agent_class)
    agent_class = ALGOS.get(agent_class)

  train_ckpt_dir = osp.join(load_root_dir, 'train')
  rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

  py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
  tf_env = tf_py_environment.TFPyEnvironment(py_env)

  if monitor:
    vid_path = os.path.join(load_root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  if run_eval:
    eval_dir = os.path.join(load_root_dir, 'eval')
    n_envs = n_envs or num_eval_episodes
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            prefix='EvalMetrics',
            buffer_size=num_eval_episodes,
            batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(
            prefix='EvalMetrics',
            buffer_size=num_eval_episodes,
            batch_size=n_envs)
    ] + [
        tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
        for m in eval_metrics
    ]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment([
            lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
        ] * n_envs))
    if seed:
      seeds = [seed * n_envs + i for i in range(n_envs)]
      try:
        eval_tf_env.pyenv.seed(seeds)
      except:
        pass

  global_step = tf.compat.v1.train.get_or_create_global_step()

  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=agents.normal_projection_net)

  critic_net = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=critic_joint_fc_layers)

  if agent_class in SAFETY_AGENTS:
    safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
    tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        safety_critic_network=safety_critic_net,
        train_step_counter=global_step,
        debug_summaries=False)
  else:
    tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        train_step_counter=global_step,
        debug_summaries=False)

  collect_data_spec = tf_agent.collect_data_spec

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec, batch_size=1, max_length=1000000)
  replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

  tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
  if agent_class in SAFETY_AGENTS:
    target_safety = target_safety or tf_agent._target_safety
  loaded_train_steps = global_step.numpy()
  logging.info('Loaded agent from %s trained for %d steps', train_ckpt_dir,
               loaded_train_steps)
  global_step.assign(0)
  tf.summary.experimental.set_step(global_step)

  thresholds = [target_safety, 0.5]
  sc_metrics = [
      tf.keras.metrics.AUC(name='safety_critic_auc'),
      tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc', threshold=0.5),
      tf.keras.metrics.TruePositives(
          name='safety_critic_tp', thresholds=thresholds),
      tf.keras.metrics.FalsePositives(
          name='safety_critic_fp', thresholds=thresholds),
      tf.keras.metrics.TrueNegatives(
          name='safety_critic_tn', thresholds=thresholds),
      tf.keras.metrics.FalseNegatives(
          name='safety_critic_fn', thresholds=thresholds)
  ]

  if seed:
    tf.compat.v1.set_random_seed(seed)

  summaries_flush_secs = 10
  timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
  offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
  config_saver = gin.tf.GinConfigSaverHook(
      offline_train_dir, summarize_config=True)
  tf.function(config_saver.after_create_session)()

  sc_summary_writer = tf.compat.v2.summary.create_file_writer(
      offline_train_dir, flush_millis=summaries_flush_secs * 1000)
  sc_summary_writer.set_as_default()

  if safety_critic_kernel_scale is not None:
    ki = tf.compat.v1.variance_scaling_initializer(
        scale=safety_critic_kernel_scale,
        mode='fan_in',
        distribution='truncated_normal')
  else:
    ki = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1. / 3., mode='fan_in', distribution='uniform')

  if safety_critic_bias_init_val is not None:
    bi = tf.constant_initializer(safety_critic_bias_init_val)
  else:
    bi = None
  sc_net_off = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=safety_critic_joint_fc_layers,
      kernel_initializer=ki,
      value_bias_initializer=bi,
      name='SafetyCriticOffline')
  sc_net_off.create_variables()
  target_sc_net_off = common.maybe_copy_target_network_with_checks(
      sc_net_off, None, 'TargetSafetyCriticNetwork')
  optimizer = tf.keras.optimizers.Adam(safety_critic_lr)

  sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
  sc_checkpointer = common.Checkpointer(
      ckpt_dir=sc_net_off_ckpt_dir,
      safety_critic=sc_net_off,
      target_safety_critic=target_sc_net_off,
      optimizer=optimizer,
      global_step=global_step,
      max_to_keep=5)
  sc_checkpointer.initialize_or_restore()

  resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
  eval_policy = agents.SafeActorPolicyRSVar(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_net,
      safety_critic_network=sc_net_off,
      safety_threshold=target_safety,
      resample_counter=resample_counter,
      training=True)

  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3, num_steps=2,
      sample_batch_size=batch_size // 2).prefetch(3)
  data = iter(dataset)

  full_data = replay_buffer.gather_all()

  fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
  fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
  init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
  before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
  after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
  before_fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
  after_init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

  filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
  filter_mask = tf.pad(
      filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
  n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

  failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=n_failures,
      dataset_window_shift=1)
  data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

  sc_dataset_neg = failure_buffer.as_dataset(
      num_parallel_calls=3, sample_batch_size=batch_size // 2,
      num_steps=2).prefetch(3)
  neg_data = iter(sc_dataset_neg)

  get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
  eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                              after_init_step, get_action)

  losses = []
  mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')

  target_update = train_utils.get_target_updater(sc_net_off, target_sc_net_off)

  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    while global_step.numpy() < num_global_steps:
      pos_experience, _ = next(data)
      neg_experience, _ = next(neg_data)
      exp = data_utils.concat_batches(pos_experience, neg_experience,
                                      collect_data_spec)
      boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
      exp = nest_utils.fast_map_structure(
          lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
      safe_rew = exp.observation['task_agn_rew'][:, 1]
      if fail_weight:
        weights = tf.where(
            tf.cast(safe_rew, tf.bool), fail_weight / 0.5,
            (1 - fail_weight) / 0.5)
      else:
        weights = None
      train_loss, sc_loss, lam_loss = train_step(
          exp,
          safe_rew,
          tf_agent,
          sc_net=sc_net_off,
          target_sc_net=target_sc_net_off,
          metrics=sc_metrics,
          weights=weights,
          target_safety=target_safety,
          optimizer=optimizer,
          target_update=target_update,
          debug_summaries=debug_summaries)
      global_step.assign_add(1)
      global_step_val = global_step.numpy()
      losses.append((train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
      mean_loss(train_loss)
      with tf.name_scope('Losses'):
        tf.compat.v2.summary.scalar(
            name='sc_loss', data=sc_loss, step=global_step_val)
        tf.compat.v2.summary.scalar(
            name='lam_loss', data=lam_loss, step=global_step_val)
        if global_step_val % summary_interval == 0:
          tf.compat.v2.summary.scalar(
              name=mean_loss.name, data=mean_loss.result(),
              step=global_step_val)
      if global_step_val % summary_interval == 0:
        with tf.name_scope('Metrics'):
          for metric in sc_metrics:
            if len(tf.squeeze(metric.result()).shape) == 0:
              tf.compat.v2.summary.scalar(
                  name=metric.name, data=metric.result(),
                  step=global_step_val)
            else:
              fmt_str = '_{}'.format(thresholds[0])
              tf.compat.v2.summary.scalar(
                  name=metric.name + fmt_str,
                  data=metric.result()[0],
                  step=global_step_val)
              fmt_str = '_{}'.format(thresholds[1])
              tf.compat.v2.summary.scalar(
                  name=metric.name + fmt_str,
                  data=metric.result()[1],
                  step=global_step_val)
            metric.reset_states()
      if global_step_val % eval_interval == 0:
        eval_sc(sc_net_off, step=global_step_val)
        if run_eval:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix='EvalMetrics',
          )
          if train_metrics_callback is not None:
            train_metrics_callback(results, global_step_val)
          metric_utils.log_metrics(eval_metrics)
          with eval_summary_writer.as_default():
            for eval_metric in eval_metrics[2:]:
              eval_metric.tf_summaries(
                  train_step=global_step, step_metrics=eval_metrics[:2])
      if monitor and global_step_val % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step,
                                              monitor_policy_state)
          action, monitor_policy_state = (monitor_action.action,
                                          monitor_action.state)
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug(
            'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
            global_step_val, ep_len, time.time() - monitor_start)
      if global_step_val % train_checkpoint_interval == 0:
        sc_checkpointer.save(global_step=global_step_val)
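# Illustrative invocation sketch only (not from the original configuration):
# the checkpoint root, environment loader, env id, and agent key below are
# hypothetical placeholders that show how the arguments fit together. The only
# grounded constraint is that a string agent_class must be a key in ALGOS.
#
#   train_eval(
#       load_root_dir='/tmp/safety_critic/run_0',  # hypothetical checkpoint root
#       env_load_fn=my_env_load_fn,                # hypothetical env loader
#       env_name='SomeSafetyEnv-v0',               # hypothetical env id
#       agent_class='sqrl',                        # hypothetical ALGOS key
#       num_global_steps=10000,
#       run_eval=True,
#       monitor=False)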
def prepare_scannet_scene_dataset(inputs, valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if 'mesh/vertices/positions' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .point_positions] = inputs['mesh/vertices/positions']
  if 'mesh/vertices/normals' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .point_normals] = inputs['mesh/vertices/normals']
    prepared_inputs[standard_fields.InputDataFields.point_normals] = tf.where(
        tf.math.is_nan(
            prepared_inputs[standard_fields.InputDataFields.point_normals]),
        tf.zeros_like(
            prepared_inputs[standard_fields.InputDataFields.point_normals]),
        prepared_inputs[standard_fields.InputDataFields.point_normals])
  if 'mesh/vertices/colors' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .point_colors] = inputs['mesh/vertices/colors'][:, 0:3]
    prepared_inputs[standard_fields.InputDataFields.point_colors] = tf.cast(
        prepared_inputs[standard_fields.InputDataFields.point_colors],
        dtype=tf.float32)
    prepared_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 /
                                                                      255.0)
    prepared_inputs[standard_fields.InputDataFields.point_colors] -= 1.0
  if 'scene_name' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image_name] = inputs['scene_name']
  if 'mesh/vertices/semantic_labels' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields
        .object_class_points] = inputs['mesh/vertices/semantic_labels']
  if 'mesh/vertices/instance_labels' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_instance_id_points] = tf.reshape(
            inputs['mesh/vertices/instance_labels'], [-1])
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[
                  standard_fields.InputDataFields.object_class_points],
              object_class))
    valid_objects_mask = tf.cast(
        valid_objects_mask,
        dtype=prepared_inputs[
            standard_fields.InputDataFields.object_class_points].dtype)
    prepared_inputs[standard_fields.InputDataFields
                    .object_class_points] *= valid_objects_mask
  return prepared_inputs
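# Illustrative usage sketch (not part of the original pipeline): the prepare_*
# functions are plain tensor-dict -> tensor-dict maps, so they can be applied
# per example with tf.data. The `raw_dataset` argument and the class ids are
# hypothetical placeholders.
def _example_prepare_scannet_scene(raw_dataset):
  """Applies prepare_scannet_scene_dataset to a hypothetical tf.data.Dataset."""
  import functools  # local import keeps this sketch self-contained
  prepare_fn = functools.partial(
      prepare_scannet_scene_dataset, valid_object_classes=[3, 4, 5])
  # Each element of raw_dataset is assumed to be a dictionary with the
  # 'mesh/vertices/*' and 'scene_name' keys consumed above.
  return raw_dataset.map(
      prepare_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)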
def prepare_kitti_dataset(inputs, valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  prepared_inputs[standard_fields.InputDataFields.point_positions] = inputs[
      standard_fields.InputDataFields.point_positions]
  prepared_inputs[standard_fields.InputDataFields.point_intensities] = inputs[
      standard_fields.InputDataFields.point_intensities]
  prepared_inputs[standard_fields.InputDataFields
                  .camera_intrinsics] = inputs['cameras/cam02/intrinsics/K']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_rotation_matrix] = inputs['cameras/cam02/extrinsics/R']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_translation] = inputs['cameras/cam02/extrinsics/t']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_image] = inputs['cameras/cam02/image']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_raw_image] = inputs['cameras/cam02/image']
  prepared_inputs[standard_fields.InputDataFields
                  .camera_original_image] = inputs['cameras/cam02/image']
  if 'scene_name' in inputs and 'frame_name' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.camera_image_name] = tf.strings.join(
            [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_length] = tf.reshape(
            inputs['objects/shape/dimension'][:, 0], [-1, 1])
    prepared_inputs[standard_fields.InputDataFields.objects_width] = tf.reshape(
        inputs['objects/shape/dimension'][:, 1], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_height] = tf.reshape(
            inputs['objects/shape/dimension'][:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    prepared_inputs[standard_fields.InputDataFields.objects_class] = tf.reshape(
        inputs['objects/category/label'], [-1, 1])
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[standard_fields.InputDataFields.objects_class],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[standard_fields.InputDataFields.objects_class],
              object_class))
    valid_objects_mask = tf.reshape(valid_objects_mask, [-1])
    for key in standard_fields.get_input_object_fields():
      if key in prepared_inputs:
        prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                               valid_objects_mask)
  return prepared_inputs
def prepare_waymo_open_dataset(inputs,
                               valid_object_classes=None,
                               max_object_distance_from_source=74.88):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    valid_object_classes: List of valid object classes. If None, it is ignored.
    max_object_distance_from_source: Maximum distance of objects from source.
      It will be ignored if None.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if standard_fields.InputDataFields.point_positions in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_positions] = inputs[
        standard_fields.InputDataFields.point_positions]
  if standard_fields.InputDataFields.point_intensities in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_intensities] = inputs[
        standard_fields.InputDataFields.point_intensities]
  if standard_fields.InputDataFields.point_elongations in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_elongations] = inputs[
        standard_fields.InputDataFields.point_elongations]
  if standard_fields.InputDataFields.point_normals in inputs:
    prepared_inputs[standard_fields.InputDataFields.point_normals] = inputs[
        standard_fields.InputDataFields.point_normals]
  if 'cameras/front/intrinsics/K' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_intrinsics] = inputs['cameras/front/intrinsics/K']
  if 'cameras/front/extrinsics/R' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields
        .camera_rotation_matrix] = inputs['cameras/front/extrinsics/R']
  if 'cameras/front/extrinsics/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_translation] = inputs['cameras/front/extrinsics/t']
  if 'cameras/front/image' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image] = inputs['cameras/front/image']
    prepared_inputs[standard_fields.InputDataFields
                    .camera_raw_image] = inputs['cameras/front/image']
    prepared_inputs[standard_fields.InputDataFields
                    .camera_original_image] = inputs['cameras/front/image']
  if 'scene_name' in inputs and 'frame_name' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.camera_image_name] = tf.strings.join(
            [inputs['scene_name'], inputs['frame_name']], separator='_')
  if 'objects/pose/R' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_rotation_matrix] = inputs['objects/pose/R']
  if 'objects/pose/t' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .objects_center] = inputs['objects/pose/t']
  if 'objects/shape/dimension' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.objects_length] = tf.reshape(
            inputs['objects/shape/dimension'][:, 0], [-1, 1])
    prepared_inputs[standard_fields.InputDataFields.objects_width] = tf.reshape(
        inputs['objects/shape/dimension'][:, 1], [-1, 1])
    prepared_inputs[
        standard_fields.InputDataFields.objects_height] = tf.reshape(
            inputs['objects/shape/dimension'][:, 2], [-1, 1])
  if 'objects/category/label' in inputs:
    prepared_inputs[standard_fields.InputDataFields.objects_class] = tf.reshape(
        inputs['objects/category/label'], [-1, 1])
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[standard_fields.InputDataFields.objects_class],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[standard_fields.InputDataFields.objects_class],
              object_class))
    valid_objects_mask = tf.reshape(valid_objects_mask, [-1])
    for key in standard_fields.get_input_object_fields():
      if key in prepared_inputs:
        prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                               valid_objects_mask)
  if max_object_distance_from_source is not None:
    if standard_fields.InputDataFields.objects_center in prepared_inputs:
      object_distances = tf.norm(
          prepared_inputs[standard_fields.InputDataFields.objects_center][:,
                                                                          0:2],
          axis=1)
      valid_mask = tf.less(object_distances, max_object_distance_from_source)
      for key in standard_fields.get_input_object_fields():
        if key in prepared_inputs:
          prepared_inputs[key] = tf.boolean_mask(prepared_inputs[key],
                                                 valid_mask)
  return prepared_inputs
def prepare_scannet_frame_dataset(inputs,
                                  min_pixel_depth=0.3,
                                  max_pixel_depth=6.0,
                                  valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    min_pixel_depth: Pixels with depth values less than this are pruned.
    max_pixel_depth: Pixels with depth values more than this are pruned.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if 'cameras/rgbd_camera/intrinsics/K' not in inputs:
    raise ValueError('Intrinsic matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/R' not in inputs:
    raise ValueError('Extrinsic rotation matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/t' not in inputs:
    raise ValueError('Extrinsics translation is missing.')
  if 'cameras/rgbd_camera/depth_image' not in inputs:
    raise ValueError('Depth image is missing.')
  if 'cameras/rgbd_camera/color_image' not in inputs:
    raise ValueError('Color image is missing.')
  if 'frame_name' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image_name] = inputs['frame_name']
  camera_intrinsics = inputs['cameras/rgbd_camera/intrinsics/K']
  depth_image = inputs['cameras/rgbd_camera/depth_image']
  image_height = tf.shape(depth_image)[0]
  image_width = tf.shape(depth_image)[1]
  x, y = tf.meshgrid(
      tf.range(image_width), tf.range(image_height), indexing='xy')
  x = tf.reshape(tf.cast(x, dtype=tf.float32) + 0.5, [-1, 1])
  y = tf.reshape(tf.cast(y, dtype=tf.float32) + 0.5, [-1, 1])
  point_positions = projections.image_frame_to_camera_frame(
      image_frame=tf.concat([x, y], axis=1),
      camera_intrinsics=camera_intrinsics)
  rotate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/R']
  translate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/t']
  point_positions = projections.to_world_frame(
      camera_frame_points=point_positions,
      rotate_world_to_camera=rotate_world_to_camera,
      translate_world_to_camera=translate_world_to_camera)
  prepared_inputs[standard_fields.InputDataFields
                  .point_positions] = point_positions * tf.reshape(
                      depth_image, [-1, 1])
  depth_values = tf.reshape(depth_image, [-1])
  valid_depth_mask = tf.logical_and(
      tf.greater_equal(depth_values, min_pixel_depth),
      tf.less_equal(depth_values, max_pixel_depth))
  prepared_inputs[standard_fields.InputDataFields.point_colors] = tf.reshape(
      tf.cast(inputs['cameras/rgbd_camera/color_image'], dtype=tf.float32),
      [-1, 3])
  prepared_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 / 255.0)
  prepared_inputs[standard_fields.InputDataFields.point_colors] -= 1.0
  prepared_inputs[
      standard_fields.InputDataFields.point_positions] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_positions],
          valid_depth_mask)
  prepared_inputs[
      standard_fields.InputDataFields.point_colors] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_colors],
          valid_depth_mask)
  if 'cameras/rgbd_camera/semantic_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/semantic_image'], [-1, 1]),
            dtype=tf.int32)
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.boolean_mask(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            valid_depth_mask)
  if 'cameras/rgbd_camera/instance_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_instance_id_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/instance_image'], [-1]),
            dtype=tf.int32)
    prepared_inputs[standard_fields.InputDataFields
                    .object_instance_id_points] = tf.boolean_mask(
                        prepared_inputs[standard_fields.InputDataFields
                                        .object_instance_id_points],
                        valid_depth_mask)
  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[
                  standard_fields.InputDataFields.object_class_points],
              object_class))
    valid_objects_mask = tf.cast(
        valid_objects_mask,
        dtype=prepared_inputs[
            standard_fields.InputDataFields.object_class_points].dtype)
    prepared_inputs[standard_fields.InputDataFields
                    .object_class_points] *= valid_objects_mask
  return prepared_inputs
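# Illustrative sketch (not from the original module) contrasting the two
# valid_object_classes behaviors above: the ScanNet preparers multiply the
# per-point labels by a 0/1 mask (invalid classes become label 0), while the
# KITTI/Waymo preparers drop the invalid object rows with tf.boolean_mask.
# The label values below are toy assumptions.
def _example_valid_class_masking():
  """Shows point-style zeroing versus object-style row filtering."""
  labels = tf.constant([[2], [5], [7]], dtype=tf.int32)
  valid_object_classes = [2, 7]
  mask = tf.cast(tf.zeros_like(labels, dtype=tf.int32), dtype=tf.bool)
  for object_class in valid_object_classes:
    mask = tf.logical_or(mask, tf.equal(labels, object_class))
  # Point-style (ScanNet): invalid labels are zeroed in place.
  point_style = labels * tf.cast(mask, labels.dtype)            # [[2], [0], [7]]
  # Object-style (KITTI/Waymo): invalid rows are removed entirely.
  object_style = tf.boolean_mask(labels, tf.reshape(mask, [-1]))  # [[2], [7]]
  return point_style, object_style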