    def check_constraint_batched_tf(self,
                                    environment: Dict,
                                    predictions: Dict,
                                    actions: Dict,
                                    batch_size: int,
                                    state_sequence_length: int):
        # construct network inputs
        net_inputs = {
            'batch_size': batch_size,
            'time':       state_sequence_length,
        }
        net_inputs.update(make_dict_tf_float32(environment))

        for action_key in self.action_keys:
            net_inputs[action_key] = tf.cast(actions[action_key], tf.float32)

        for state_key in self.state_keys:
            planned_state_key = add_predicted(state_key)
            net_inputs[planned_state_key] = tf.cast(predictions[state_key], tf.float32)

        if self.hparams['stdev']:
            net_inputs[add_predicted('stdev')] = tf.cast(predictions['stdev'], tf.float32)

        net_inputs = make_dict_tf_float32(net_inputs)
        mean_predictions, stdev_predictions = self.check_constraint_from_example(net_inputs, training=False)
        mean_probability = mean_predictions['probabilities']
        stdev_probability = stdev_predictions['probabilities']
        mean_probability = tf.squeeze(mean_probability, axis=2)
        stdev_probability = tf.squeeze(stdev_probability, axis=2)
        return mean_probability, stdev_probability
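A minimal usage sketch for the method above; `model`, the dict contents, and the shapes are illustrative assumptions, not confirmed API beyond what the method itself reads:

    # b = batch size, T = state_sequence_length
    # environment: full-environment tensors (e.g. 'env', 'origin', 'res', 'extent')
    # predictions: one [b, T, ...] tensor per state key, plus 'stdev' when hparams['stdev'] is set
    # actions:     one [b, T - 1, ...] tensor per action key
    mean_probability, stdev_probability = model.check_constraint_batched_tf(
        environment, predictions, actions, batch_size=b, state_sequence_length=T)
    # after the squeeze on axis 2, both outputs have shape [b, T - 1]:
    # one accept probability per transition (the first state has none)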
# Example #2
    def __init__(
        self,
        dataset_dirs: List[pathlib.Path],
        use_gt_rope: bool,
        load_true_states=False,
        no_balance=True,
        threshold: Optional[float] = None,
        old_compat: Optional[bool] = False,
    ):
        super(ClassifierDatasetLoader, self).__init__(dataset_dirs)
        self.no_balance = no_balance
        self.load_true_states = load_true_states
        self.use_gt_rope = use_gt_rope
        self.labeling_params = self.hparams['labeling_params']
        self.threshold = threshold if threshold is not None else self.labeling_params['threshold']
        rospy.loginfo(f"classifier using threshold {self.threshold}")
        self.horizon = self.hparams['labeling_params']['classifier_horizon']
        self.scenario = get_scenario(self.hparams['scenario'])

        self.true_state_keys = self.hparams['true_state_keys']
        self.old_compat = old_compat
        if self.old_compat:
            self.true_state_keys.append('is_close')
        else:
            self.true_state_keys.append('error')

        self.predicted_state_keys = [
            add_predicted(k) for k in self.hparams['predicted_state_keys']
        ]
        self.predicted_state_keys.append(add_predicted('stdev'))

        self.action_keys = self.hparams['action_keys']

        self.feature_names = [
            'classifier_start_t',
            'classifier_end_t',
            'env',
            'extent',
            'origin',
            'res',
            'traj_idx',
            'prediction_start_t',
        ]

        self.batch_metadata = {'time': self.horizon}

        if self.load_true_states:
            for k in self.true_state_keys:
                self.feature_names.append(k)

        for k in self.predicted_state_keys:
            self.feature_names.append(k)

        for k in self.action_keys:
            self.feature_names.append(k)
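A hedged construction sketch; the dataset path is a hypothetical placeholder, and get_datasets(mode=...) is called the same way in the later examples:

    loader = ClassifierDatasetLoader([pathlib.Path('classifier_data/example_dataset')],
                                     use_gt_rope=True,
                                     load_true_states=True)
    train_dataset = loader.get_datasets(mode='train')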
    def __init__(self, hparams: Dict, batch_size: int, scenario: Base3DScenario):
        super().__init__(hparams, batch_size)
        self.scenario = scenario

        self.raster_debug_pubs = [
            rospy.Publisher(f'classifier_raster_debug_{i}', OccupancyStamped, queue_size=10, latch=False)
            for i in range(4)
        ]
        self.local_env_bbox_pub = rospy.Publisher('local_env_bbox', BoundingBox, queue_size=10, latch=True)

        self.classifier_dataset_hparams = self.hparams['classifier_dataset_hparams']
        self.dynamics_dataset_hparams = self.classifier_dataset_hparams['fwd_model_hparams']['dynamics_dataset_hparams']
        self.true_state_keys = self.classifier_dataset_hparams['true_state_keys']
        self.pred_state_keys = [add_predicted(k) for k in self.classifier_dataset_hparams['predicted_state_keys']]
        self.pred_state_keys.append(add_predicted('stdev'))
        self.local_env_h_rows = self.hparams['local_env_h_rows']
        self.local_env_w_cols = self.hparams['local_env_w_cols']
        self.local_env_c_channels = self.hparams['local_env_c_channels']
        self.rope_image_k = self.hparams['rope_image_k']

        # TODO: add stdev to states keys?
        self.state_keys = self.hparams['state_keys']
        self.action_keys = self.hparams['action_keys']

        self.conv_layers = []
        self.pool_layers = []
        for n_filters, kernel_size in self.hparams['conv_filters']:
            conv = layers.Conv3D(n_filters,
                                 kernel_size,
                                 activation='relu',
                                 kernel_regularizer=keras.regularizers.l2(self.hparams['kernel_reg']),
                                 bias_regularizer=keras.regularizers.l2(self.hparams['bias_reg']))
            pool = layers.MaxPool3D(self.hparams['pooling'])
            self.conv_layers.append(conv)
            self.pool_layers.append(pool)

        if self.hparams['batch_norm']:
            self.batch_norm = layers.BatchNormalization()

        self.dense_layers = []
        for hidden_size in self.hparams['fc_layer_sizes']:
            dense = layers.Dense(hidden_size,
                                 activation='relu',
                                 kernel_regularizer=keras.regularizers.l2(self.hparams['kernel_reg']),
                                 bias_regularizer=keras.regularizers.l2(self.hparams['bias_reg']))
            self.dense_layers.append(dense)

        # self.local_env_shape = (self.local_env_h_rows, self.local_env_w_cols, self.local_env_c_channels)
        # self.encoder = tf.keras.applications.ResNet50(include_top=False, weights=None, input_shape=self.local_env_shape)

        self.lstm = layers.LSTM(self.hparams['rnn_size'], unroll=True, return_sequences=True)
        self.output_layer = layers.Dense(1, activation=None)
        self.sigmoid = layers.Activation("sigmoid")
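For reference, the constructor above reads the following hparams keys; the values shown are made-up placeholders, not settings from the source:

    hparams = {
        'classifier_dataset_hparams': {...},        # must itself contain 'fwd_model_hparams', key lists, etc.
        'local_env_h_rows': 64,
        'local_env_w_cols': 64,
        'local_env_c_channels': 64,
        'rope_image_k': 1000.0,
        'state_keys': ['rope'],
        'action_keys': ['gripper_position'],
        'conv_filters': [(16, (3, 3, 3)), (32, (3, 3, 3))],  # (n_filters, kernel_size) pairs
        'kernel_reg': 0.0,
        'bias_reg': 0.0,
        'pooling': (2, 2, 2),
        'batch_norm': True,
        'fc_layer_sizes': [256, 256],
        'rnn_size': 128,
        'stdev': True,             # read later, in call()
        'with_robot_frame': True,  # read later, in call()
    }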
    def plot_state_rviz(self, state: Dict, label: str, **kwargs):
        r, g, b, a = colors.to_rgba(kwargs.get("color", "r"))
        idx = kwargs.get("idx", 0)
        ig = marker_index_generator(idx)

        msg = MarkerArray()
        if rope_key_name in state:
            rope_points = np.reshape(state[rope_key_name], [-1, 3])
            rope_mrkrs = make_rope_marker(rope_points,
                                          'world',
                                          label + "_gt_" + rope_key_name,
                                          next(ig),
                                          r,
                                          g,
                                          b,
                                          a,
                                          s=0.04)
            points_marker, lines, midpoint_sphere, first_point_text = rope_mrkrs
            msg.markers.append(lines)
            # msg.markers.append(points_marker)
            # msg.markers.append(first_point_text)

        if 'gripper' in state:
            gripper = state['gripper']
            gripper_sphere = make_gripper_marker(gripper,
                                                 next(ig),
                                                 r,
                                                 g,
                                                 b,
                                                 a,
                                                 label + '_gt_gripper',
                                                 Marker.SPHERE,
                                                 s=0.04)
            msg.markers.append(gripper_sphere)

        if add_predicted(rope_key_name) in state:
            rope_points = np.reshape(state[add_predicted(rope_key_name)],
                                     [-1, 3])
            markers = make_rope_marker(rope_points,
                                       'world', label + "_" + rope_key_name,
                                       next(ig), r, g, b, a)
            msg.markers.extend(markers)

        if add_predicted('gripper') in state:
            pred_gripper = state[add_predicted('gripper')]
            gripper_sphere = make_gripper_marker(pred_gripper, next(ig), r, g,
                                                 b, a, label + "_gripper",
                                                 Marker.SPHERE)
            msg.markers.append(gripper_sphere)

        self.state_viz_pub.publish(msg)
    def call(self, input_dict: Dict, training, **kwargs):
        batch_size = input_dict['batch_size']
        time = tf.cast(input_dict['time'], tf.int32)

        conv_output, debug_info_seq = self.make_traj_voxel_grids_from_input_dict(input_dict, batch_size, time)

        states = {k: input_dict[add_predicted(k)] for k in self.state_keys}
        states_in_local_frame = self.scenario.put_state_local_frame(states)
        actions = {k: input_dict[k] for k in self.action_keys}
        all_but_last_states = {k: v[:, :-1] for k, v in states.items()}
        actions = self.scenario.put_action_local_frame(all_but_last_states, actions)
        padded_actions = [tf.pad(v, [[0, 0], [0, 1], [0, 0]]) for v in actions.values()]
        if 'with_robot_frame' not in self.hparams:
            print("no hparam 'with_robot_frame'. This must be an old model!")
            concat_args = [conv_output] + list(states_in_local_frame.values()) + padded_actions
        elif self.hparams['with_robot_frame']:
            states_in_robot_frame = self.scenario.put_state_robot_frame(states)
            concat_args = ([conv_output] + list(states_in_robot_frame.values()) +
                           list(states_in_local_frame.values()) + padded_actions)
        else:
            concat_args = [conv_output] + list(states_in_local_frame.values()) + padded_actions

        if self.hparams['stdev']:
            stdevs = input_dict[add_predicted('stdev')]
            concat_args.append(stdevs)

        concat_output = tf.concat(concat_args, axis=2)

        if self.hparams['batch_norm']:
            concat_output = self.batch_norm(concat_output, training=training)

        z = concat_output
        for dense_layer in self.dense_layers:
            z = dense_layer(z)
        out_d = z

        out_h = self.lstm(out_d)

        # for every timestep's output, map down to a single scalar, the logit for accept probability
        all_accept_logits = self.output_layer(out_h)
        # ignore the first output, it is meaningless to predict the validity of a single state
        valid_accept_logits = all_accept_logits[:, 1:]
        valid_accept_probabilities = self.sigmoid(valid_accept_logits)

        if DEBUG_VIZ:
            self.debug_rviz(input_dict, debug_info_seq)

        return {
            'logits':        valid_accept_logits,
            'probabilities': valid_accept_probabilities,
        }
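Given batch size b and sequence length T, all_accept_logits is [b, T, 1], so both returned tensors are [b, T - 1, 1]: one accept logit/probability per transition, the first timestep having been dropped as the comment above explains.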
# Example #6
    def state_encoder(self, input_dict, batch_size, training):
        # get only the start states
        start_state = {k: input_dict[add_predicted(k)][:, 0] for k in self.states_keys}
        local_env_center_point = self.scenario.local_environment_center_differentiable(start_state)
        images = self.make_trajectory_images(environment=self.scenario.get_environment_from_example(input_dict),
                                             start_states=start_state,
                                             local_env_center_point=local_env_center_point,
                                             batch_size=batch_size)
        # import matplotlib.pyplot as plt
        # from matplotlib import cm
        # cmap = cm.viridis
        # out_image = state_image_to_cmap(images[0], cmap=cmap)
        # plt.imshow(out_image)
        # plt.show()
        conv_output = self._conv(images, batch_size)
        concat_args = [conv_output]
        for k, v in start_state.items():
            # note: this assumes all state vectors are [x1, y1, ..., xn, yn]
            points = tf.reshape(v, [batch_size, -1, 2])
            points = points - points[:, tf.newaxis, 0]  # make each point relative to the first point
            v = tf.reshape(points, [batch_size, -1])
            concat_args.append(v)
        conv_output = tf.concat(concat_args, axis=1)
        if self.hparams['batch_norm']:
            conv_output = self.batch_norm(conv_output, training=training)
        z = conv_output
        for dense_layer in self.env_state_encoder_dense_layers:
            z = dense_layer(z)
        return images, z
# Example #7
    def sample(self, environment: Dict, state: Dict):
        input_dict = environment
        input_dict.update({add_predicted(k): tf.expand_dims(v, axis=0) for k, v in state.items()})
        input_dict = add_batch(input_dict)
        input_dict = {k: tf.cast(v, tf.float32) for k, v in input_dict.items()}
        output = self.net.sample(input_dict)
        output = remove_batch(output)
        output = numpify(output)
        return output
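A hedged usage sketch; `policy` stands in for an instance of the class above, and the environment/state keys are assumptions. The state values are single unbatched tensors, since sample() adds and removes the batch dimension itself:

    output = policy.sample(environment={'env': env, 'origin': origin, 'res': res},
                           state={'rope': rope_state})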
    def test_add_remove_predicted_dict(self):
        d = {
            add_predicted("test1"): 1,
            "test2": 2,
        }
        expected_d = {
            "test1": 1,
            "test2": 2,
        }
        out_d = remove_predicted_from_dict(d)
        self.assertEqual(expected_d, out_d)
# Example #9
def _process_example(dataset: ClassifierDatasetLoader, example: Dict):
    example['left_gripper'] = example.pop('gripper1')
    example['right_gripper'] = example.pop('gripper2')
    example['left_gripper_position'] = example.pop('gripper1_position')
    example['right_gripper_position'] = example.pop('gripper2_position')
    example['rope'] = example.pop('link_bot')
    example[add_predicted('left_gripper')] = example.pop(add_predicted('gripper1'))
    example[add_predicted('right_gripper')] = example.pop(add_predicted('gripper2'))
    example[add_predicted('rope')] = example.pop(add_predicted('link_bot'))
    yield example
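Because _process_example is a generator, a hedged usage sketch is simply to iterate it per example:

    for example in dataset:
        for renamed in _process_example(loader, example):
            ...  # renamed now uses 'left_gripper'/'right_gripper'/'rope' names,
                 # with matching add_predicted(...) keys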
def visualize_dataset(args, classifier_dataset):
    tf_dataset = classifier_dataset.get_datasets(mode=args.mode, take=args.take)

    tf_dataset = tf_dataset.batch(1)

    iterator = iter(tf_dataset)
    t0 = perf_counter()

    reconverging_count = 0
    positive_count = 0
    negative_count = 0
    count = 0

    stdevs = []
    labels = []
    stdevs_for_negative = []
    stdevs_for_positive = []

    for i, example in enumerate(progressbar(tf_dataset, widgets=base_dataset.widgets)):
        example = remove_batch(example)

        is_close = example['is_close'].numpy().squeeze()
        count += is_close.shape[0]

        n_close = np.count_nonzero(is_close[-1])
        n_far = is_close.shape[0] - n_close
        positive_count += n_close
        negative_count += n_far
        reconverging = n_far > 0 and is_close[-1]

        if args.only_reconverging and not reconverging:
            continue

        if args.only_negative and np.any(is_close[1:]):
            continue

        if args.only_positive and not np.any(is_close[1:]):
            continue

        # print(f"Example {i}, Trajectory #{int(example['traj_idx'])}")

        if i == 0:
            print_dict(example)

        if reconverging:
            reconverging_count += 1

        # Print statistics intermittently
        if count % 1000 == 0:
            print_stats_and_timing(args, count, reconverging_count, negative_count, positive_count)

        #############################
        # Show Visualization
        #############################
        if args.display_type == 'just_count':
            continue
        elif args.display_type == '3d':
            # print(example['is_close'])
            if example['is_close'][0] == 0:
                continue
            classifier_dataset.anim_transition_rviz(example)

        elif args.display_type == 'stdev':
            for t in range(1, classifier_dataset.horizon):
                stdev_t = example[add_predicted('stdev')][t, 0].numpy()
                label_t = example['is_close'][t]
                stdevs.append(stdev_t)
                labels.append(label_t)
                if label_t > 0.5:
                    stdevs_for_positive.append(stdev_t)
                else:
                    stdevs_for_negative.append(stdev_t)
        else:
            raise NotImplementedError()
    total_dt = perf_counter() - t0

    if args.display_type == 'stdev':
        print(f"p={stats.f_oneway(stdevs_for_negative, stdevs_for_positive)[1]}")

        plt.figure()
        plt.title(" ".join([str(d.name) for d in args.dataset_dirs]))
        bins = plt.hist(stdevs_for_negative, label='negative examples', alpha=0.8, density=True)[1]
        plt.hist(stdevs_for_positive, label='positive examples', alpha=0.8, bins=bins, density=True)
        plt.ylabel("count")
        plt.xlabel("stdev")
        plt.legend()
        plt.show()

    print_stats_and_timing(args, count, reconverging_count, negative_count, positive_count, total_dt)
def viz_ensemble_main(dataset_dir: pathlib.Path,
                      checkpoints: List[pathlib.Path], mode: str,
                      batch_size: int, only_errors: bool, use_gt_rope: bool,
                      **kwargs):
    dynamics_stdev_pub_ = rospy.Publisher("dynamics_stdev",
                                          Float32,
                                          queue_size=10)
    classifier_stdev_pub_ = rospy.Publisher("classifier_stdev",
                                            Float32,
                                            queue_size=10)
    accept_probability_pub_ = rospy.Publisher("accept_probability_viz",
                                              Float32,
                                              queue_size=10)
    traj_idx_pub_ = rospy.Publisher("traj_idx_viz", Float32, queue_size=10)

    ###############
    # Model
    ###############
    model = load_generic_model(checkpoints)

    ###############
    # Dataset
    ###############
    test_dataset = ClassifierDatasetLoader([dataset_dir],
                                           load_true_states=True,
                                           use_gt_rope=use_gt_rope)
    test_tf_dataset = test_dataset.get_datasets(mode=mode)
    test_tf_dataset = batch_tf_dataset(test_tf_dataset,
                                       batch_size,
                                       drop_remainder=True)
    scenario = test_dataset.scenario

    ###############
    # Evaluate
    ###############

    # Iterate over test set
    all_accuracies_over_time = []
    all_stdevs = []
    all_labels = []
    classifier_ensemble_stdevs = []
    for batch_idx, test_batch in enumerate(test_tf_dataset):
        test_batch.update(test_dataset.batch_metadata)

        mean_predictions, stdev_predictions = model.check_constraint_from_example(
            test_batch)
        mean_probabilities = mean_predictions['probabilities']
        stdev_probabilities = stdev_predictions['probabilities']

        labels = tf.expand_dims(test_batch['is_close'][:, 1:], axis=2)

        all_labels = tf.concat(
            (all_labels, tf.reshape(test_batch['is_close'][:, 1:], [-1])),
            axis=0)
        all_stdevs = tf.concat(
            (all_stdevs, tf.reshape(test_batch[add_predicted('stdev')], [-1])),
            axis=0)

        accuracy_over_time = tf.keras.metrics.binary_accuracy(
            y_true=labels, y_pred=mean_probabilities)
        all_accuracies_over_time.append(accuracy_over_time)

        # Visualization
        test_batch.pop("time")
        test_batch.pop("batch_size")
        decisions = mean_probabilities > 0.5
        classifier_is_correct = tf.squeeze(tf.equal(decisions,
                                                    tf.cast(labels, tf.bool)),
                                           axis=-1)
        for b in range(batch_size):
            example = index_dict_of_batched_tensors_tf(test_batch, b)

            classifier_ensemble_stdev = stdev_probabilities[b].numpy().squeeze()
            classifier_ensemble_stdevs.append(classifier_ensemble_stdev)

            # if the classifier is correct at all time steps, ignore
            if only_errors and tf.reduce_all(classifier_is_correct[b]):
                continue

            # if only_collision
            predicted_rope_states = tf.reshape(
                example[add_predicted('link_bot')][1], [-1, 3])
            xs = predicted_rope_states[:, 0]
            ys = predicted_rope_states[:, 1]
            zs = predicted_rope_states[:, 2]
            in_collision = bool(
                batch_in_collision_tf_3d(environment=example,
                                         xs=xs,
                                         ys=ys,
                                         zs=zs,
                                         inflate_radius_m=0)[0].numpy())
            label = bool(example['is_close'][1].numpy())
            accept = decisions[b, 0, 0].numpy()
            # if not (in_collision and accept):
            #     continue

            scenario.plot_environment_rviz(example)

            classifier_stdev_msg = Float32()
            classifier_stdev_msg.data = stdev_probabilities[b].numpy().squeeze()
            classifier_stdev_pub_.publish(classifier_stdev_msg)

            actual_0 = scenario.index_state_time(example, 0)
            actual_1 = scenario.index_state_time(example, 1)
            pred_0 = scenario.index_predicted_state_time(example, 0)
            pred_1 = scenario.index_predicted_state_time(example, 1)
            action = scenario.index_action_time(example, 0)
            label = example['is_close'][1]
            scenario.plot_state_rviz(actual_0,
                                     label='actual',
                                     color='#FF0000AA',
                                     idx=0)
            scenario.plot_state_rviz(actual_1,
                                     label='actual',
                                     color='#E00016AA',
                                     idx=1)
            scenario.plot_state_rviz(pred_0,
                                     label='predicted',
                                     color='#0000FFAA',
                                     idx=0)
            scenario.plot_state_rviz(pred_1,
                                     label='predicted',
                                     color='#0553FAAA',
                                     idx=1)
            scenario.plot_action_rviz(pred_0, action)
            scenario.plot_is_close(label)

            dynamics_stdev_t = example[add_predicted('stdev')][1, 0].numpy()
            dynamics_stdev_msg = Float32()
            dynamics_stdev_msg.data = dynamics_stdev_t
            dynamics_stdev_pub_.publish(dynamics_stdev_msg)

            accept_probability_t = mean_probabilities[b, 0, 0].numpy()
            accept_probability_msg = Float32()
            accept_probability_msg.data = accept_probability_t
            accept_probability_pub_.publish(accept_probability_msg)

            traj_idx_msg = Float32()
            traj_idx_msg.data = batch_idx * batch_size + b
            traj_idx_pub_.publish(traj_idx_msg)

            # stepper = RvizSimpleStepper()
            # stepper.step()

        print(np.mean(classifier_ensemble_stdevs))

    all_accuracies_over_time = tf.concat(all_accuracies_over_time, axis=0)
    mean_accuracies_over_time = tf.reduce_mean(all_accuracies_over_time,
                                               axis=0)
    std_accuracies_over_time = tf.math.reduce_std(all_accuracies_over_time,
                                                  axis=0)
    mean_classifier_ensemble_stdev = tf.reduce_mean(classifier_ensemble_stdevs)
    print(mean_classifier_ensemble_stdev)
# Example #12
    def plot_state_rviz(self, state: Dict, label: str, **kwargs):
        r, g, b, a = colors.to_rgba(kwargs.get("color", "r"))
        idx = kwargs.get("idx", 0)

        msg = MarkerArray()

        ig = marker_index_generator(idx)

        if 'gt_rope' in state:
            rope_points = np.reshape(state['gt_rope'], [-1, 3])
            markers = make_rope_marker(rope_points, 'world',
                                       label + "_gt_rope", next(ig), r, g, b,
                                       a)
            msg.markers.extend(markers)

        if 'rope' in state:
            rope_points = np.reshape(state['rope'], [-1, 3])
            markers = make_rope_marker(rope_points, 'world', label + "_rope",
                                       next(ig), r, g, b, a)
            msg.markers.extend(markers)

        if add_predicted('rope') in state:
            rope_points = np.reshape(state[add_predicted('rope')], [-1, 3])
            markers = make_rope_marker(rope_points, 'world',
                                       label + "_pred_rope", next(ig), r, g, b,
                                       a, Marker.CUBE_LIST)
            msg.markers.extend(markers)

        if 'left_gripper' in state:
            r = 0.2
            g = 0.2
            b = 0.8
            left_gripper_sphere = make_gripper_marker(state['left_gripper'],
                                                      next(ig), r, g, b, a,
                                                      label + '_l',
                                                      Marker.SPHERE)
            msg.markers.append(left_gripper_sphere)

        if 'right_gripper' in state:
            r = 0.8
            g = 0.8
            b = 0.2
            right_gripper_sphere = make_gripper_marker(state['right_gripper'],
                                                       next(ig), r, g, b, a,
                                                       label + "_r",
                                                       Marker.SPHERE)
            msg.markers.append(right_gripper_sphere)

        if add_predicted('left_gripper') in state:
            r = 0.2
            g = 0.2
            b = 0.8
            lgpp = state[add_predicted('left_gripper')]
            left_gripper_sphere = make_gripper_marker(lgpp, next(ig), r, g, b,
                                                      a, label + "_lp",
                                                      Marker.CUBE)
            msg.markers.append(left_gripper_sphere)

        if add_predicted('right_gripper') in state:
            r = 0.8
            g = 0.8
            b = 0.2
            rgpp = state[add_predicted('right_gripper')]
            right_gripper_sphere = make_gripper_marker(rgpp, next(ig), r, g, b,
                                                       a, label + "_rp",
                                                       Marker.CUBE)
            msg.markers.append(right_gripper_sphere)

        s = kwargs.get("scale", 1.0)
        msg = scale_marker_array(msg, s)

        self.state_viz_pub.publish(msg)

        if in_maybe_predicted('rgbd', state):
            publish_color_image(self.state_color_viz_pub,
                                state['rgbd'][:, :, :3])
            publish_depth_image(self.state_depth_viz_pub, state['rgbd'][:, :, 3])

        if add_predicted('stdev') in state:
            stdev_t = state[add_predicted('stdev')][0]
            self.plot_stdev(stdev_t)

        if 'error' in state:
            error_msg = Float32()
            error_t = state['error']
            error_msg.data = error_t
            self.error_pub.publish(error_msg)
    def test_add_remove_predicted(self):
        k = "test"
        out_k = remove_predicted(add_predicted(k))
        self.assertEqual(k, out_k)
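This test and the dict round-trip test above pin down the add/remove contract; a minimal sketch of the assumed implementation (the exact prefix string is a guess for illustration only):

    PREDICTED_PREFIX = 'predicted/'

    def add_predicted(feature_name: str) -> str:
        return PREDICTED_PREFIX + feature_name

    def remove_predicted(feature_name: str) -> str:
        return feature_name[len(PREDICTED_PREFIX):]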
def stdev_viz_t(pub: rospy.Publisher):
    return float32_viz_t(pub, add_predicted('stdev'))
def generate_classifier_examples_from_batch(
        scenario: ExperimentScenario,
        prediction_actual: PredictionActualExample):
    labeling_params = prediction_actual.labeling_params
    prediction_horizon = prediction_actual.actual_prediction_horizon
    classifier_horizon = labeling_params['classifier_horizon']

    valid_out_examples = []
    for classifier_start_t in range(
            0, prediction_horizon - classifier_horizon + 1):
        classifier_end_t = classifier_start_t + classifier_horizon

        prediction_start_t = prediction_actual.prediction_start_t
        prediction_start_t_batched = tf.cast(
            tf.stack([prediction_start_t] * prediction_actual.batch_size,
                     axis=0), tf.float32)
        classifier_start_t_batched = tf.cast(
            tf.stack([classifier_start_t] * prediction_actual.batch_size,
                     axis=0), tf.float32)
        classifier_end_t_batched = tf.cast(
            tf.stack([classifier_end_t] * prediction_actual.batch_size,
                     axis=0), tf.float32)
        out_example = {
            'env': prediction_actual.dataset_element['env'],
            'origin': prediction_actual.dataset_element['origin'],
            'extent': prediction_actual.dataset_element['extent'],
            'res': prediction_actual.dataset_element['res'],
            'traj_idx': prediction_actual.dataset_element['traj_idx'],
            'prediction_start_t': prediction_start_t_batched,
            'classifier_start_t': classifier_start_t_batched,
            'classifier_end_t': classifier_end_t_batched,
        }

        # this slice gives arrays of fixed length (e.g., 5) which must be null padded from out_example_end_idx onwards
        state_slice = slice(classifier_start_t,
                            classifier_start_t + classifier_horizon)
        action_slice = slice(classifier_start_t,
                             classifier_start_t + classifier_horizon - 1)
        sliced_actual = {}
        for key, actual_state_component in prediction_actual.actual_states.items():
            actual_state_component_sliced = actual_state_component[:, state_slice]
            out_example[key] = actual_state_component_sliced
            sliced_actual[key] = actual_state_component_sliced

        sliced_predictions = {}
        for key, prediction_component in prediction_actual.predictions.items():
            prediction_component_sliced = prediction_component[:, state_slice]
            out_example[add_predicted(key)] = prediction_component_sliced
            sliced_predictions[key] = prediction_component_sliced

        # action
        sliced_actions = {}
        for key, action_component in prediction_actual.actions.items():
            action_component_sliced = action_component[:, action_slice]
            out_example[key] = action_component_sliced
            sliced_actions[key] = action_component_sliced

        # compute label
        threshold = labeling_params['threshold']
        error = scenario.classifier_distance(sliced_actual, sliced_predictions)
        is_close = error < threshold
        out_example['error'] = tf.cast(error, dtype=tf.float32)

        # perception reliability
        if 'perception_reliability_method' in labeling_params:
            pr_method = labeling_params['perception_reliability_method']
            if pr_method == 'gt':
                perception_reliability = gt_perception_reliability(
                    scenario, sliced_actual, sliced_predictions)
                out_example['perception_reliability'] = perception_reliability
            else:
                raise NotImplementedError(
                    f"unrecognized perception reliability method {pr_method}")

        is_first_predicted_state_close = is_close[:, 0]
        valid_indices = tf.where(is_first_predicted_state_close)
        valid_indices = tf.squeeze(valid_indices, axis=1)
        # keep only valid_indices from every key in out_example...
        valid_out_example = gather_dict(out_example, valid_indices)
        valid_out_examples.append(valid_out_example)
        # valid_out_examples.append(out_example)
    return valid_out_examples
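gather_dict is this codebase's helper; a minimal sketch of the assumed behavior, keeping only the selected batch rows of every entry (illustration only):

    def gather_dict(d: Dict, indices):
        # keep only the rows at `indices` along the batch axis of every value
        return {k: tf.gather(v, indices, axis=0) for k, v in d.items()}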
    def make_traj_voxel_grids_from_input_dict(self, input_dict: Dict, batch_size, time: int):
        # Construct a [b, h, w, c, 3] grid of the indices which make up the local environment
        pixel_row_indices = tf.range(0, self.local_env_h_rows, dtype=tf.float32)
        pixel_col_indices = tf.range(0, self.local_env_w_cols, dtype=tf.float32)
        pixel_channel_indices = tf.range(0, self.local_env_c_channels, dtype=tf.float32)
        x_indices, y_indices, z_indices = tf.meshgrid(pixel_col_indices, pixel_row_indices, pixel_channel_indices)

        # Make batched versions for creating the local environment
        batch_y_indices = tf.cast(tf.tile(tf.expand_dims(y_indices, axis=0), [batch_size, 1, 1, 1]), tf.int64)
        batch_x_indices = tf.cast(tf.tile(tf.expand_dims(x_indices, axis=0), [batch_size, 1, 1, 1]), tf.int64)
        batch_z_indices = tf.cast(tf.tile(tf.expand_dims(z_indices, axis=0), [batch_size, 1, 1, 1]), tf.int64)

        # Convert for rastering state
        pixel_indices = tf.stack([y_indices, x_indices, z_indices], axis=3)
        pixel_indices = tf.expand_dims(pixel_indices, axis=0)
        pixel_indices = tf.tile(pixel_indices, [batch_size, 1, 1, 1, 1])

        conv_outputs_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
        debug_info_seq = []
        for t in tf.range(time):
            state_t = {k: input_dict[add_predicted(k)][:, t] for k in self.state_keys}

            local_env_center_t = self.scenario.local_environment_center_differentiable(state_t)
            # by converting to and from the frame of the full environment, we ensure the grids are aligned
            indices = batch_point_to_idx_tf_3d_in_batched_envs(local_env_center_t, input_dict)
            local_env_center_t = batch_idx_to_point_3d_in_env_tf(*indices, input_dict)

            local_env_t, local_env_origin_t = get_local_env(center_point=local_env_center_t,
                                                            full_env=input_dict['env'],
                                                            full_env_origin=input_dict['origin'],
                                                            res=input_dict['res'],
                                                            local_h_rows=self.local_env_h_rows,
                                                            local_w_cols=self.local_env_w_cols,
                                                            local_c_channels=self.local_env_c_channels,
                                                            batch_x_indices=batch_x_indices,
                                                            batch_y_indices=batch_y_indices,
                                                            batch_z_indices=batch_z_indices,
                                                            batch_size=batch_size)

            local_voxel_grid_t_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
            local_voxel_grid_t_array = local_voxel_grid_t_array.write(0, local_env_t)
            for i, (k, state_component_t) in enumerate(state_t.items()):
                state_component_voxel_grid = raster_3d(state=state_component_t,
                                                       pixel_indices=pixel_indices,
                                                       res=input_dict['res'],
                                                       origin=local_env_origin_t,
                                                       h=self.local_env_h_rows,
                                                       w=self.local_env_w_cols,
                                                       c=self.local_env_c_channels,
                                                       k=self.rope_image_k,
                                                       batch_size=batch_size)

                local_voxel_grid_t_array = local_voxel_grid_t_array.write(i + 1, state_component_voxel_grid)
            local_voxel_grid_t = tf.transpose(local_voxel_grid_t_array.stack(), [1, 2, 3, 4, 0])
            # add channel dimension information because tf.function erases it somehow...
            local_voxel_grid_t.set_shape([None, None, None, None, len(self.state_keys) + 1])

            out_conv_z = self.fwd_conv(batch_size, local_voxel_grid_t)

            conv_outputs_array = conv_outputs_array.write(t, out_conv_z)

            if DEBUG_VIZ:
                debug_info_seq.append((state_t, local_env_origin_t, local_env_t, local_voxel_grid_t))

        conv_outputs = conv_outputs_array.stack()
        return tf.transpose(conv_outputs, [1, 0, 2]), debug_info_seq
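Shape summary for the method above, assuming fwd_conv flattens each voxel grid to [b, conv_dim]: the stacked TensorArray is [time, b, conv_dim], and the final transpose yields [b, time, conv_dim], ready to be concatenated with the per-timestep state and action features in call().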