def check_constraint_batched_tf(self,
                                environment: Dict,
                                predictions: Dict,
                                actions: Dict,
                                batch_size: int,
                                state_sequence_length: int):
    """Evaluate the constraint on a batch of predicted trajectories.

    Builds the network input dict from the environment, the actions named in
    ``self.action_keys`` and the predictions named in ``self.state_keys``
    (stored under their ``add_predicted`` keys), then forwards it to
    ``check_constraint_from_example`` with ``training=False``.

    Args:
        environment: environment entries, converted with ``make_dict_tf_float32``.
        predictions: predicted state components; must also contain ``'stdev'``
            when ``self.hparams['stdev']`` is truthy.
        actions: action components keyed by ``self.action_keys``.
        batch_size: stored under ``'batch_size'`` in the network inputs.
        state_sequence_length: stored under ``'time'`` in the network inputs.

    Returns:
        ``(mean_probability, stdev_probability)``: the ``'probabilities'``
        entries of the two returned dicts with axis 2 squeezed out.
    """

    def _as_float32(value):
        return tf.cast(value, tf.float32)

    # construct network inputs
    net_inputs = {'batch_size': batch_size, 'time': state_sequence_length}
    net_inputs.update(make_dict_tf_float32(environment))
    net_inputs.update({key: _as_float32(actions[key]) for key in self.action_keys})
    for key in self.state_keys:
        net_inputs[add_predicted(key)] = _as_float32(predictions[key])
    if self.hparams['stdev']:
        net_inputs[add_predicted('stdev')] = _as_float32(predictions['stdev'])
    net_inputs = make_dict_tf_float32(net_inputs)

    mean_out, stdev_out = self.check_constraint_from_example(net_inputs, training=False)
    mean_probability = mean_out['probabilities']
    stdev_probability = stdev_out['probabilities']
    return tf.squeeze(mean_probability, axis=2), tf.squeeze(stdev_probability, axis=2)
def __init__(
        self,
        dataset_dirs: List[pathlib.Path],
        use_gt_rope: bool,
        load_true_states=False,
        no_balance=True,
        threshold: Optional[float] = None,
        old_compat: Optional[bool] = False,
):
    """Set up a classifier dataset loader over ``dataset_dirs``.

    Args:
        dataset_dirs: forwarded to the base-class constructor.
        use_gt_rope: stored on the instance.
        load_true_states: when truthy, the true state keys are added to
            ``feature_names``.
        no_balance: stored on the instance.
        threshold: overrides ``labeling_params['threshold']`` when not None.
        old_compat: selects ``'is_close'`` instead of ``'error'`` as the extra
            true-state key.
    """
    super(ClassifierDatasetLoader, self).__init__(dataset_dirs)
    self.no_balance = no_balance
    self.load_true_states = load_true_states
    self.use_gt_rope = use_gt_rope

    self.labeling_params = self.hparams['labeling_params']
    if threshold is None:
        self.threshold = self.labeling_params['threshold']
    else:
        self.threshold = threshold
    rospy.loginfo(f"classifier using threshold {self.threshold}")
    self.horizon = self.hparams['labeling_params']['classifier_horizon']
    self.scenario = get_scenario(self.hparams['scenario'])

    # NOTE: no copy is made here, so the append below also extends the
    # sequence stored in self.hparams['true_state_keys'].
    self.true_state_keys = self.hparams['true_state_keys']
    self.old_compat = old_compat
    self.true_state_keys.append('is_close' if self.old_compat else 'error')

    predicted_keys = [add_predicted(k) for k in self.hparams['predicted_state_keys']]
    predicted_keys.append(add_predicted('stdev'))
    self.predicted_state_keys = predicted_keys

    self.action_keys = self.hparams['action_keys']

    self.feature_names = [
        'classifier_start_t',
        'classifier_end_t',
        'env',
        'extent',
        'origin',
        'res',
        'traj_idx',
        'prediction_start_t',
    ]
    self.batch_metadata = {'time': self.horizon}

    # true states are optional; predicted states and actions are always loaded
    if self.load_true_states:
        self.feature_names.extend(self.true_state_keys)
    self.feature_names.extend(self.predicted_state_keys)
    self.feature_names.extend(self.action_keys)
def __init__(self, hparams: Dict, batch_size: int, scenario: Base3DScenario):
    """Create the debug publishers, read hyperparameters and build the layers.

    Args:
        hparams: model hyperparameters; the keys read here include
            'classifier_dataset_hparams', 'local_env_*', 'rope_image_k',
            'state_keys', 'action_keys', 'conv_filters', 'pooling',
            'kernel_reg', 'bias_reg', 'batch_norm', 'fc_layer_sizes' and
            'rnn_size'.
        batch_size: forwarded to the base-class constructor.
        scenario: stored as ``self.scenario``.
    """
    super().__init__(hparams, batch_size)
    self.scenario = scenario

    raster_pubs = []
    for i in range(4):
        raster_pubs.append(
            rospy.Publisher(f'classifier_raster_debug_{i}', OccupancyStamped, queue_size=10, latch=False))
    self.raster_debug_pubs = raster_pubs
    self.local_env_bbox_pub = rospy.Publisher('local_env_bbox', BoundingBox, queue_size=10, latch=True)

    dataset_hparams = self.hparams['classifier_dataset_hparams']
    self.classifier_dataset_hparams = dataset_hparams
    self.dynamics_dataset_hparams = dataset_hparams['fwd_model_hparams']['dynamics_dataset_hparams']
    self.true_state_keys = dataset_hparams['true_state_keys']
    pred_state_keys = [add_predicted(k) for k in dataset_hparams['predicted_state_keys']]
    pred_state_keys.append(add_predicted('stdev'))
    self.pred_state_keys = pred_state_keys

    self.local_env_h_rows = self.hparams['local_env_h_rows']
    self.local_env_w_cols = self.hparams['local_env_w_cols']
    self.local_env_c_channels = self.hparams['local_env_c_channels']
    self.rope_image_k = self.hparams['rope_image_k']

    # TODO: add stdev to states keys?
    self.state_keys = self.hparams['state_keys']
    self.action_keys = self.hparams['action_keys']

    def _l2_regularizers():
        # fresh regularizer objects for every layer, as each layer gets its own pair
        return {
            'kernel_regularizer': keras.regularizers.l2(self.hparams['kernel_reg']),
            'bias_regularizer': keras.regularizers.l2(self.hparams['bias_reg']),
        }

    # one conv layer followed by one max-pool layer per 'conv_filters' entry
    self.conv_layers = []
    self.pool_layers = []
    for n_filters, kernel_size in self.hparams['conv_filters']:
        conv_layer = layers.Conv3D(n_filters, kernel_size, activation='relu', **_l2_regularizers())
        pool_layer = layers.MaxPool3D(self.hparams['pooling'])
        self.conv_layers.append(conv_layer)
        self.pool_layers.append(pool_layer)

    if self.hparams['batch_norm']:
        self.batch_norm = layers.BatchNormalization()

    self.dense_layers = []
    for hidden_size in self.hparams['fc_layer_sizes']:
        self.dense_layers.append(layers.Dense(hidden_size, activation='relu', **_l2_regularizers()))

    self.lstm = layers.LSTM(self.hparams['rnn_size'], unroll=True, return_sequences=True)
    self.output_layer = layers.Dense(1, activation=None)
    self.sigmoid = layers.Activation("sigmoid")
def plot_state_rviz(self, state: Dict, label: str, **kwargs):
    """Publish a MarkerArray for whichever rope/gripper entries ``state`` has.

    Args:
        state: may contain ``rope_key_name``, ``'gripper'`` and their
            ``add_predicted`` counterparts; absent entries are skipped.
        label: prefix for the marker labels.
        **kwargs: ``color`` (default ``"r"``, passed to ``colors.to_rgba``) and
            ``idx`` (default 0, seeds ``marker_index_generator``).
    """
    red, green, blue, alpha = colors.to_rgba(kwargs.get("color", "r"))
    marker_ids = marker_index_generator(kwargs.get("idx", 0))
    marker_array = MarkerArray()

    predicted_rope_key = add_predicted(rope_key_name)
    predicted_gripper_key = add_predicted('gripper')

    if rope_key_name in state:
        gt_points = np.reshape(state[rope_key_name], [-1, 3])
        gt_rope_markers = make_rope_marker(gt_points, 'world', label + "_gt_" + rope_key_name, next(marker_ids),
                                           red, green, blue, alpha, s=0.04)
        # of the four markers returned, only the line marker is published
        _, line_marker, _, _ = gt_rope_markers
        marker_array.markers.append(line_marker)

    if 'gripper' in state:
        marker_array.markers.append(
            make_gripper_marker(state['gripper'], next(marker_ids), red, green, blue, alpha,
                                label + '_gt_gripper', Marker.SPHERE, s=0.04))

    if predicted_rope_key in state:
        predicted_points = np.reshape(state[predicted_rope_key], [-1, 3])
        marker_array.markers.extend(
            make_rope_marker(predicted_points, 'world', label + "_" + rope_key_name, next(marker_ids),
                             red, green, blue, alpha))

    if predicted_gripper_key in state:
        marker_array.markers.append(
            make_gripper_marker(state[predicted_gripper_key], next(marker_ids), red, green, blue, alpha,
                                label + "_gripper", Marker.SPHERE))

    self.state_viz_pub.publish(marker_array)
def call(self, input_dict: Dict, training, **kwargs):
    """Forward pass: per-timestep features -> dense layers -> LSTM -> accept logits.

    Args:
        input_dict: must contain 'batch_size', 'time', the ``add_predicted``
            entries for ``self.state_keys``, the entries for
            ``self.action_keys`` and, when ``hparams['stdev']`` is truthy, the
            predicted 'stdev'.
        training: forwarded to batch normalisation when it is enabled.

    Returns:
        Dict with 'logits' and 'probabilities'; the output for the first
        timestep is dropped from both.
    """
    batch_size = input_dict['batch_size']
    time = tf.cast(input_dict['time'], tf.int32)

    conv_output, debug_info_seq = self.make_traj_voxel_grids_from_input_dict(input_dict, batch_size, time)

    states = {k: input_dict[add_predicted(k)] for k in self.state_keys}
    states_in_local_frame = self.scenario.put_state_local_frame(states)

    raw_actions = {k: input_dict[k] for k in self.action_keys}
    states_before_last = {k: v[:, :-1] for k, v in states.items()}
    local_actions = self.scenario.put_action_local_frame(states_before_last, raw_actions)
    # pad one step at the end of axis 1 so actions can be concatenated with the states
    padded_actions = [tf.pad(v, [[0, 0], [0, 1], [0, 0]]) for v in local_actions.values()]

    # feature order: conv output, [robot-frame states], local-frame states, actions, [stdev]
    concat_args = [conv_output]
    if 'with_robot_frame' not in self.hparams:
        print("no hparam 'with_robot_frame'. This must be an old model!")
    elif self.hparams['with_robot_frame']:
        states_in_robot_frame = self.scenario.put_state_robot_frame(states)
        concat_args.extend(states_in_robot_frame.values())
    concat_args.extend(states_in_local_frame.values())
    concat_args.extend(padded_actions)
    if self.hparams['stdev']:
        concat_args.append(input_dict[add_predicted('stdev')])

    features = tf.concat(concat_args, axis=2)
    if self.hparams['batch_norm']:
        features = self.batch_norm(features, training=training)

    for dense_layer in self.dense_layers:
        features = dense_layer(features)
    lstm_output = self.lstm(features)

    # for every timestep's output, map down to a single scalar, the logit for accept probability
    all_accept_logits = self.output_layer(lstm_output)
    # ignore the first output, it is meaningless to predict the validity of a single state
    valid_accept_logits = all_accept_logits[:, 1:]
    valid_accept_probabilities = self.sigmoid(valid_accept_logits)

    if DEBUG_VIZ:
        self.debug_rviz(input_dict, debug_info_seq)

    return {
        'logits': valid_accept_logits,
        'probabilities': valid_accept_probabilities,
    }
def state_encoder(self, input_dict, batch_size, training):
    """Encode the start state and its trajectory image into a feature vector.

    Args:
        input_dict: must contain the ``add_predicted`` entries for
            ``self.states_keys`` plus whatever
            ``scenario.get_environment_from_example`` reads.
        batch_size: used for the image construction and the reshapes.
        training: forwarded to batch normalisation when it is enabled.

    Returns:
        ``(images, z)``: the trajectory images and the output of the last
        layer in ``self.env_state_encoder_dense_layers``.
    """
    # get only the start states
    start_state = {}
    for key in self.states_keys:
        start_state[key] = input_dict[add_predicted(key)][:, 0]

    local_env_center_point = self.scenario.local_environment_center_differentiable(start_state)
    environment = self.scenario.get_environment_from_example(input_dict)
    images = self.make_trajectory_images(environment=environment,
                                         start_states=start_state,
                                         local_env_center_point=local_env_center_point,
                                         batch_size=batch_size)

    features = [self._conv(images, batch_size)]
    for state_vector in start_state.values():
        # note this assumes all state vectors are [x1,y1,...,xn,yn]
        points = tf.reshape(state_vector, [batch_size, -1, 2])
        # NOTE(review): points[:, :, tf.newaxis, 0] selects index 0 of the last
        # axis for every point, so each point's own first coordinate is
        # subtracted from both of its coordinates. Confirm this is intended
        # rather than subtracting the first point.
        points = points - points[:, :, tf.newaxis, 0]
        features.append(tf.reshape(points, [batch_size, -1]))

    z = tf.concat(features, axis=1)
    if self.hparams['batch_norm']:
        z = self.batch_norm(z, training=training)
    for dense_layer in self.env_state_encoder_dense_layers:
        z = dense_layer(z)
    return images, z
def sample(self, environment: Dict, state: Dict):
    """Draw a sample from ``self.net`` for one environment/state pair.

    Each ``state`` entry is expanded with a new leading axis and stored under
    its ``add_predicted`` key alongside the environment entries. The combined
    dict is batched with ``add_batch``, cast to float32 and passed to
    ``self.net.sample``. The result is un-batched and converted with
    ``numpify``.

    Args:
        environment: environment entries. Not modified by this call.
        state: state components keyed by their un-predicted names.

    Returns:
        The output of ``self.net.sample`` after ``remove_batch`` and ``numpify``.
    """
    # Shallow-copy first: updating `environment` directly would leak the
    # predicted-state entries into the caller's dict.
    input_dict = dict(environment)
    input_dict.update({add_predicted(k): tf.expand_dims(v, axis=0) for k, v in state.items()})
    input_dict = add_batch(input_dict)
    input_dict = {k: tf.cast(v, tf.float32) for k, v in input_dict.items()}
    output = self.net.sample(input_dict)
    output = remove_batch(output)
    return numpify(output)
def test_add_remove_predicted_dict(self):
    """remove_predicted_from_dict undoes add_predicted on keys and keeps the other keys."""
    mixed_keys = {add_predicted("test1"): 1, "test2": 2}
    stripped = remove_predicted_from_dict(mixed_keys)
    self.assertEqual({"test1": 1, "test2": 2}, stripped)
def _process_example(dataset: ClassifierDatasetLoader, example: Dict):
    """Rename the gripper/rope keys of ``example`` in place and yield it once.

    ``dataset`` is not used. A missing source key raises ``KeyError`` from
    ``dict.pop``.
    """
    # (new key, old key), applied in this order
    plain_renames = (
        ('left_gripper', 'gripper1'),
        ('right_gripper', 'gripper2'),
        ('left_gripper_position', 'gripper1_position'),
        ('right_gripper_position', 'gripper2_position'),
        ('rope', 'link_bot'),
    )
    # the same renames for the predicted counterparts (no *_position entries)
    predicted_renames = (
        ('left_gripper', 'gripper1'),
        ('right_gripper', 'gripper2'),
        ('rope', 'link_bot'),
    )
    for new_key, old_key in plain_renames:
        example[new_key] = example.pop(old_key)
    for new_key, old_key in predicted_renames:
        example[add_predicted(new_key)] = example.pop(add_predicted(old_key))
    yield example
def visualize_dataset(args, classifier_dataset):
    """Iterate over a classifier dataset, gather label statistics and optionally visualize.

    Reads ``args.mode``, ``args.take``, ``args.only_reconverging``,
    ``args.only_negative``, ``args.only_positive``, ``args.display_type`` and
    (for the 'stdev' display) ``args.dataset_dirs``.

    ``args.display_type`` selects what happens per example:
      * 'just_count': nothing beyond counting
      * '3d': ``classifier_dataset.anim_transition_rviz(example)``
      * 'stdev': collect predicted stdevs split by label, then plot histograms
      * anything else raises ``NotImplementedError``
    """
    tf_dataset = classifier_dataset.get_datasets(mode=args.mode, take=args.take)
    tf_dataset = tf_dataset.batch(1)

    # NOTE(review): `iterator` is never used below; the loop iterates tf_dataset directly.
    iterator = iter(tf_dataset)
    t0 = perf_counter()

    reconverging_count = 0
    positive_count = 0
    negative_count = 0
    count = 0

    # `stdevs` and `labels` are filled in 'stdev' mode but only the two
    # per-label lists are read afterwards
    stdevs = []
    labels = []
    stdevs_for_negative = []
    stdevs_for_positive = []

    for i, example in enumerate(progressbar(tf_dataset, widgets=base_dataset.widgets)):
        # undo the batch(1) above
        example = remove_batch(example)

        is_close = example['is_close'].numpy().squeeze()
        count += is_close.shape[0]

        # NOTE(review): n_close counts non-zeros of the LAST element only, so it
        # is 0 or 1 for a 1-D is_close, and n_far is the remainder of the
        # length. Confirm this is the intended definition of positive/negative.
        n_close = np.count_nonzero(is_close[-1])
        n_far = is_close.shape[0] - n_close
        positive_count += n_close
        negative_count += n_far
        reconverging = n_far > 0 and is_close[-1]

        # the filters below skip examples AFTER they have been counted above
        if args.only_reconverging and not reconverging:
            continue

        if args.only_negative and np.any(is_close[1:]):
            continue

        if args.only_positive and not np.any(is_close[1:]):
            continue

        # print(f"Example {i}, Trajectory #{int(example['traj_idx'])}")

        # NOTE(review): count has already been incremented above, so this only
        # fires when every is_close seen so far had length 0. Possibly `i == 0`
        # was intended; confirm.
        if count == 0:
            print_dict(example)

        if reconverging:
            reconverging_count += 1

        # Print statistics intermittently
        # (count grows by is_close.shape[0] per example, so exact multiples of 1000 can be skipped)
        if count % 1000 == 0:
            print_stats_and_timing(args, count, reconverging_count, negative_count, positive_count)

        #############################
        # Show Visualization
        #############################
        if args.display_type == 'just_count':
            continue
        elif args.display_type == '3d':
            # print(example['is_close'])
            # skip examples whose first label is 0
            if example['is_close'][0] == 0:
                continue
            classifier_dataset.anim_transition_rviz(example)

        elif args.display_type == 'stdev':
            # starts at t=1, so the entry at t=0 is not collected
            for t in range(1, classifier_dataset.horizon):
                stdev_t = example[add_predicted('stdev')][t, 0].numpy()
                label_t = example['is_close'][t]
                stdevs.append(stdev_t)
                labels.append(label_t)
                if label_t > 0.5:
                    stdevs_for_positive.append(stdev_t)
                else:
                    stdevs_for_negative.append(stdev_t)
        else:
            raise NotImplementedError()
    total_dt = perf_counter() - t0

    if args.display_type == 'stdev':
        # index [1] of the f_oneway result is printed as the p-value
        print(f"p={stats.f_oneway(stdevs_for_negative, stdevs_for_positive)[1]}")

        plt.figure()
        plt.title(" ".join([str(d.name) for d in args.dataset_dirs]))
        # reuse the bin edges of the negative histogram so both histograms are comparable
        bins = plt.hist(stdevs_for_negative, label='negative examples', alpha=0.8, density=True)[1]
        plt.hist(stdevs_for_positive, label='positive examples', alpha=0.8, bins=bins, density=True)
        plt.ylabel("count")
        plt.xlabel("stdev")
        plt.legend()
        plt.show()

    print_stats_and_timing(args, count, reconverging_count, negative_count, positive_count, total_dt)
def viz_ensemble_main(dataset_dir: pathlib.Path, checkpoints: List[pathlib.Path], mode: str, batch_size: int, only_errors: bool, use_gt_rope: bool, **kwargs):
    """Run a classifier model over a dataset and publish per-example visualizations.

    For every example the environment, the actual/predicted states at time
    indices 0 and 1, the action at index 0, the label at index 1 and several
    Float32 diagnostics are published. Accuracy statistics are computed at the
    end but only the mean classifier stdev is printed.

    Args:
        dataset_dir: the single dataset directory to load.
        checkpoints: passed to ``load_generic_model``.
        mode: dataset mode passed to ``get_datasets``.
        batch_size: batch size; the remainder batch is dropped.
        only_errors: skip examples the classifier gets right at every timestep.
        use_gt_rope: forwarded to ``ClassifierDatasetLoader``.
        **kwargs: ignored.
    """
    dynamics_stdev_pub_ = rospy.Publisher("dynamics_stdev", Float32, queue_size=10)
    classifier_stdev_pub_ = rospy.Publisher("classifier_stdev", Float32, queue_size=10)
    accept_probability_pub_ = rospy.Publisher("accept_probability_viz", Float32, queue_size=10)
    traj_idx_pub_ = rospy.Publisher("traj_idx_viz", Float32, queue_size=10)

    ###############
    # Model
    ###############
    model = load_generic_model(checkpoints)

    ###############
    # Dataset
    ###############
    test_dataset = ClassifierDatasetLoader([dataset_dir], load_true_states=True, use_gt_rope=use_gt_rope)
    test_tf_dataset = test_dataset.get_datasets(mode=mode)
    test_tf_dataset = batch_tf_dataset(test_tf_dataset, batch_size, drop_remainder=True)
    scenario = test_dataset.scenario

    ###############
    # Evaluate
    ###############

    # Iterate over test set
    all_accuracies_over_time = []
    # these two start as empty lists and are replaced by tensors via tf.concat in the loop;
    # NOTE(review): neither is read after the loop
    all_stdevs = []
    all_labels = []
    classifier_ensemble_stdevs = []
    for batch_idx, test_batch in enumerate(test_tf_dataset):
        test_batch.update(test_dataset.batch_metadata)

        mean_predictions, stdev_predictions = model.check_constraint_from_example(test_batch)
        mean_probabilities = mean_predictions['probabilities']
        stdev_probabilities = stdev_predictions['probabilities']

        # labels from index 1 on along axis 1; index 0 is dropped
        labels = tf.expand_dims(test_batch['is_close'][:, 1:], axis=2)

        all_labels = tf.concat((all_labels, tf.reshape(test_batch['is_close'][:, 1:], [-1])), axis=0)
        all_stdevs = tf.concat((all_stdevs, tf.reshape(test_batch[add_predicted('stdev')], [-1])), axis=0)

        accuracy_over_time = tf.keras.metrics.binary_accuracy(y_true=labels, y_pred=mean_probabilities)
        all_accuracies_over_time.append(accuracy_over_time)

        # Visualization
        # remove the non-per-example entries before the batch is indexed per example
        test_batch.pop("time")
        test_batch.pop("batch_size")
        decisions = mean_probabilities > 0.5
        classifier_is_correct = tf.squeeze(tf.equal(decisions, tf.cast(labels, tf.bool)), axis=-1)
        for b in range(batch_size):
            example = index_dict_of_batched_tensors_tf(test_batch, b)

            classifier_ensemble_stdev = stdev_probabilities[b].numpy().squeeze()
            classifier_ensemble_stdevs.append(classifier_ensemble_stdev)

            # if the classifier is correct at all time steps, ignore
            if only_errors and tf.reduce_all(classifier_is_correct[b]):
                continue

            # if only_collision
            # NOTE(review): in_collision, label (first assignment) and accept are
            # only used by the commented-out filter below
            predicted_rope_states = tf.reshape(example[add_predicted('link_bot')][1], [-1, 3])
            xs = predicted_rope_states[:, 0]
            ys = predicted_rope_states[:, 1]
            zs = predicted_rope_states[:, 2]
            in_collision = bool(batch_in_collision_tf_3d(environment=example, xs=xs, ys=ys, zs=zs, inflate_radius_m=0)[0].numpy())
            label = bool(example['is_close'][1].numpy())
            accept = decisions[b, 0, 0].numpy()
            # if not (in_collision and accept):
            #     continue

            scenario.plot_environment_rviz(example)

            # NOTE(review): the next line is a bare expression whose result is discarded
            stdev_probabilities[b].numpy().squeeze()
            classifier_stdev_msg = Float32()
            classifier_stdev_msg.data = stdev_probabilities[b].numpy().squeeze()
            classifier_stdev_pub_.publish(classifier_stdev_msg)

            # only the transition between time indices 0 and 1 is visualized
            actual_0 = scenario.index_state_time(example, 0)
            actual_1 = scenario.index_state_time(example, 1)
            pred_0 = scenario.index_predicted_state_time(example, 0)
            pred_1 = scenario.index_predicted_state_time(example, 1)
            action = scenario.index_action_time(example, 0)
            label = example['is_close'][1]
            scenario.plot_state_rviz(actual_0, label='actual', color='#FF0000AA', idx=0)
            scenario.plot_state_rviz(actual_1, label='actual', color='#E00016AA', idx=1)
            scenario.plot_state_rviz(pred_0, label='predicted', color='#0000FFAA', idx=0)
            scenario.plot_state_rviz(pred_1, label='predicted', color='#0553FAAA', idx=1)
            scenario.plot_action_rviz(pred_0, action)
            scenario.plot_is_close(label)

            dynamics_stdev_t = example[add_predicted('stdev')][1, 0].numpy()
            dynamics_stdev_msg = Float32()
            dynamics_stdev_msg.data = dynamics_stdev_t
            dynamics_stdev_pub_.publish(dynamics_stdev_msg)

            accept_probability_t = mean_probabilities[b, 0, 0].numpy()
            accept_probability_msg = Float32()
            accept_probability_msg.data = accept_probability_t
            accept_probability_pub_.publish(accept_probability_msg)

            # global example index, published as a float
            traj_idx_msg = Float32()
            traj_idx_msg.data = batch_idx * batch_size + b
            traj_idx_pub_.publish(traj_idx_msg)

            # stepper = RvizSimpleStepper()
            # stepper.step()

        # running mean of the classifier stdevs collected so far
        # NOTE(review): nesting level inferred; confirm this is meant to print once per batch
        print(np.mean(classifier_ensemble_stdevs))

    all_accuracies_over_time = tf.concat(all_accuracies_over_time, axis=0)
    # NOTE(review): the two accuracy statistics below are computed but never used
    mean_accuracies_over_time = tf.reduce_mean(all_accuracies_over_time, axis=0)
    std_accuracies_over_time = tf.math.reduce_std(all_accuracies_over_time, axis=0)
    mean_classifier_ensemble_stdev = tf.reduce_mean(classifier_ensemble_stdevs)
    print(mean_classifier_ensemble_stdev)
def plot_state_rviz(self, state: Dict, label: str, **kwargs):
    """Publish markers, images and scalars for whichever entries ``state`` has.

    Ropes use the colour from ``kwargs['color']`` (default ``"r"``). Grippers
    use fixed left/right colours and only take the alpha of that colour.
    Predicted entries are drawn with cube-shaped markers.

    Args:
        state: may contain 'gt_rope', 'rope', 'left_gripper', 'right_gripper',
            their ``add_predicted`` counterparts, 'rgbd', predicted 'stdev'
            and 'error'; absent entries are skipped.
        label: prefix for the marker labels.
        **kwargs: ``color``, ``idx`` (default 0) and ``scale`` (default 1.0).
    """
    red, green, blue, alpha = colors.to_rgba(kwargs.get("color", "r"))
    marker_array = MarkerArray()
    marker_ids = marker_index_generator(kwargs.get("idx", 0))

    # (state key, label suffix, extra positional args for make_rope_marker)
    rope_specs = (
        ('gt_rope', "_gt_rope", ()),
        ('rope', "_rope", ()),
        (add_predicted('rope'), "_pred_rope", (Marker.CUBE_LIST,)),
    )
    for key, suffix, extra_args in rope_specs:
        if key in state:
            rope_points = np.reshape(state[key], [-1, 3])
            marker_array.markers.extend(
                make_rope_marker(rope_points, 'world', label + suffix, next(marker_ids),
                                 red, green, blue, alpha, *extra_args))

    left_rgb = (0.2, 0.2, 0.8)
    right_rgb = (0.8, 0.8, 0.2)
    # (state key, rgb, label suffix, marker shape)
    gripper_specs = (
        ('left_gripper', left_rgb, '_l', Marker.SPHERE),
        ('right_gripper', right_rgb, "_r", Marker.SPHERE),
        (add_predicted('left_gripper'), left_rgb, "_lp", Marker.CUBE),
        (add_predicted('right_gripper'), right_rgb, "_rp", Marker.CUBE),
    )
    for key, (gripper_r, gripper_g, gripper_b), suffix, shape in gripper_specs:
        if key in state:
            marker_array.markers.append(
                make_gripper_marker(state[key], next(marker_ids), gripper_r, gripper_g, gripper_b, alpha,
                                    label + suffix, shape))

    marker_array = scale_marker_array(marker_array, kwargs.get("scale", 1.0))
    self.state_viz_pub.publish(marker_array)

    # NOTE(review): the check goes through in_maybe_predicted but the lookup
    # always uses the plain 'rgbd' key; confirm that is intended.
    if in_maybe_predicted('rgbd', state):
        publish_color_image(self.state_color_viz_pub, state['rgbd'][:, :, :3])
        publish_depth_image(self.state_depth_viz_pub, state['rgbd'][:, :, 3])

    stdev_key = add_predicted('stdev')
    if stdev_key in state:
        self.plot_stdev(state[stdev_key][0])

    if 'error' in state:
        error_msg = Float32()
        error_msg.data = state['error']
        self.error_pub.publish(error_msg)
def test_add_remove_predicted(self):
    """remove_predicted applied to add_predicted(key) gives back the key."""
    original_key = "test"
    predicted_key = add_predicted(original_key)
    self.assertEqual(original_key, remove_predicted(predicted_key))
def stdev_viz_t(pub: rospy.Publisher):
    """Return ``float32_viz_t`` for ``pub`` and the predicted 'stdev' key."""
    stdev_key = add_predicted('stdev')
    return float32_viz_t(pub, stdev_key)
def generate_classifier_examples_from_batch(scenario: ExperimentScenario,
                                            prediction_actual: PredictionActualExample):
    """Cut a batch of predicted/actual trajectories into classifier examples.

    A window of ``classifier_horizon`` states (and one fewer actions) is taken
    at every possible start time. Each window stores the sliced actual
    states, predictions (under ``add_predicted`` keys), actions and the
    ``scenario.classifier_distance`` error. Only batch elements whose first
    predicted state is within ``labeling_params['threshold']`` are kept.

    Returns:
        A list with one gathered example dict per window start time. The
        ``is_close`` label itself is not stored, only 'error'.
    """
    labeling_params = prediction_actual.labeling_params
    prediction_horizon = prediction_actual.actual_prediction_horizon
    classifier_horizon = labeling_params['classifier_horizon']

    def _repeat_over_batch(value):
        # one float32 copy of `value` per batch element
        return tf.cast(tf.stack([value] * prediction_actual.batch_size, axis=0), tf.float32)

    def _slice_time(components, time_slice):
        # slice axis 1 of every component
        return {name: component[:, time_slice] for name, component in components.items()}

    valid_out_examples = []
    n_windows = prediction_horizon - classifier_horizon + 1
    for classifier_start_t in range(n_windows):
        classifier_end_t = classifier_start_t + classifier_horizon

        prediction_start_t_batched = _repeat_over_batch(prediction_actual.prediction_start_t)
        classifier_start_t_batched = _repeat_over_batch(classifier_start_t)
        classifier_end_t_batched = _repeat_over_batch(classifier_end_t)

        dataset_element = prediction_actual.dataset_element
        out_example = {
            'env': dataset_element['env'],
            'origin': dataset_element['origin'],
            'extent': dataset_element['extent'],
            'res': dataset_element['res'],
            'traj_idx': dataset_element['traj_idx'],
            'prediction_start_t': prediction_start_t_batched,
            'classifier_start_t': classifier_start_t_batched,
            'classifier_end_t': classifier_end_t_batched,
        }

        # this slice gives arrays of fixed length (ex, 5) which must be null padded from out_example_end_idx onwards
        state_slice = slice(classifier_start_t, classifier_end_t)
        action_slice = slice(classifier_start_t, classifier_end_t - 1)

        sliced_actual = _slice_time(prediction_actual.actual_states, state_slice)
        out_example.update(sliced_actual)

        sliced_predictions = _slice_time(prediction_actual.predictions, state_slice)
        out_example.update({add_predicted(name): sliced for name, sliced in sliced_predictions.items()})

        # action
        out_example.update(_slice_time(prediction_actual.actions, action_slice))

        # compute label
        threshold = labeling_params['threshold']
        error = scenario.classifier_distance(sliced_actual, sliced_predictions)
        is_close = error < threshold
        out_example['error'] = tf.cast(error, dtype=tf.float32)

        # perception reliability
        if 'perception_reliability_method' in labeling_params:
            pr_method = labeling_params['perception_reliability_method']
            if pr_method != 'gt':
                raise NotImplementedError(f"unrecognized perception reliability method {pr_method}")
            out_example['perception_reliability'] = gt_perception_reliability(scenario, sliced_actual,
                                                                              sliced_predictions)

        # keep only the batch elements whose first predicted state is close
        valid_indices = tf.squeeze(tf.where(is_close[:, 0]), axis=1)
        valid_out_examples.append(gather_dict(out_example, valid_indices))

    return valid_out_examples
def make_traj_voxel_grids_from_input_dict(self, input_dict: Dict, batch_size, time: int):
    """Build a local voxel grid per timestep and run each through ``self.fwd_conv``.

    For every timestep a local environment is extracted around
    ``scenario.local_environment_center_differentiable`` of the predicted
    state. Each state component is rastered into its own grid, and the stack
    (local env first, then one grid per state component) is passed to
    ``self.fwd_conv``.

    Args:
        input_dict: must contain 'env', 'origin', 'res' and the
            ``add_predicted`` entries for ``self.state_keys``.
        batch_size: used to tile the index grids and passed to the helpers.
        time: number of timesteps to iterate over.

    Returns:
        ``(conv_outputs, debug_info_seq)``: the stacked conv outputs with the
        first two axes swapped (presumably to [batch, time, features]; TODO
        confirm), and a list that is only filled when ``DEBUG_VIZ`` is truthy.
    """
    # Construct a [b, h, w, c, 3] grid of the indices which make up the local environment
    pixel_row_indices = tf.range(0, self.local_env_h_rows, dtype=tf.float32)
    pixel_col_indices = tf.range(0, self.local_env_w_cols, dtype=tf.float32)
    pixel_channel_indices = tf.range(0, self.local_env_c_channels, dtype=tf.float32)
    # note the argument order: columns (x) first, then rows (y), then channels (z)
    x_indices, y_indices, z_indices = tf.meshgrid(pixel_col_indices, pixel_row_indices, pixel_channel_indices)

    # Make batched versions for creating the local environment
    batch_y_indices = tf.cast(tf.tile(tf.expand_dims(y_indices, axis=0), [batch_size, 1, 1, 1]), tf.int64)
    batch_x_indices = tf.cast(tf.tile(tf.expand_dims(x_indices, axis=0), [batch_size, 1, 1, 1]), tf.int64)
    batch_z_indices = tf.cast(tf.tile(tf.expand_dims(z_indices, axis=0), [batch_size, 1, 1, 1]), tf.int64)

    # Convert for rastering state
    # (float32 indices stacked in y, x, z order on a new last axis, then tiled over the batch)
    pixel_indices = tf.stack([y_indices, x_indices, z_indices], axis=3)
    pixel_indices = tf.expand_dims(pixel_indices, axis=0)
    pixel_indices = tf.tile(pixel_indices, [batch_size, 1, 1, 1, 1])

    # one entry per timestep, written inside the loop below
    conv_outputs_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    debug_info_seq = []
    for t in tf.range(time):
        state_t = {k: input_dict[add_predicted(k)][:, t] for k in self.state_keys}

        local_env_center_t = self.scenario.local_environment_center_differentiable(state_t)
        # by converting to and from the frame of the full environment, we ensure the grids are aligned
        indices = batch_point_to_idx_tf_3d_in_batched_envs(local_env_center_t, input_dict)
        local_env_center_t = batch_idx_to_point_3d_in_env_tf(*indices, input_dict)

        local_env_t, local_env_origin_t = get_local_env(center_point=local_env_center_t,
                                                        full_env=input_dict['env'],
                                                        full_env_origin=input_dict['origin'],
                                                        res=input_dict['res'],
                                                        local_h_rows=self.local_env_h_rows,
                                                        local_w_cols=self.local_env_w_cols,
                                                        local_c_channels=self.local_env_c_channels,
                                                        batch_x_indices=batch_x_indices,
                                                        batch_y_indices=batch_y_indices,
                                                        batch_z_indices=batch_z_indices,
                                                        batch_size=batch_size)

        # slot 0 holds the local environment, slots 1.. hold one rastered grid per state component
        local_voxel_grid_t_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
        local_voxel_grid_t_array = local_voxel_grid_t_array.write(0, local_env_t)
        for i, (k, state_component_t) in enumerate(state_t.items()):
            state_component_voxel_grid = raster_3d(state=state_component_t,
                                                   pixel_indices=pixel_indices,
                                                   res=input_dict['res'],
                                                   origin=local_env_origin_t,
                                                   h=self.local_env_h_rows,
                                                   w=self.local_env_w_cols,
                                                   c=self.local_env_c_channels,
                                                   k=self.rope_image_k,
                                                   batch_size=batch_size)

            local_voxel_grid_t_array = local_voxel_grid_t_array.write(i + 1, state_component_voxel_grid)
        # move the stacked (grid index) axis from the front to the back
        local_voxel_grid_t = tf.transpose(local_voxel_grid_t_array.stack(), [1, 2, 3, 4, 0])
        # add channel dimension information because tf.function erases it somehow...
        local_voxel_grid_t.set_shape([None, None, None, None, len(self.state_keys) + 1])

        out_conv_z = self.fwd_conv(batch_size, local_voxel_grid_t)

        conv_outputs_array = conv_outputs_array.write(t, out_conv_z)

        if DEBUG_VIZ:
            debug_info_seq.append((state_t, local_env_origin_t, local_env_t, local_voxel_grid_t))

    conv_outputs = conv_outputs_array.stack()
    # stack() puts the loop (time) axis first; swap it with the next axis
    return tf.transpose(conv_outputs, [1, 0, 2]), debug_info_seq