Example #1
    def test_raises_error_when_seconds_negative(self):
        mock_samples = [{'token': '1', 'timestamp': 0, 'anns': ['1', '1b']}]
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)
        with self.assertRaises(ValueError):
            helper.get_future_for_agent('1', '1', -1, False)

        with self.assertRaises(ValueError):
            helper.get_past_for_agent('1', '1', -1, False)

        with self.assertRaises(ValueError):
            helper.get_past_for_sample('1', -1, False)

        with self.assertRaises(ValueError):
            helper.get_future_for_sample('1', -1, False)
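
The tests in these snippets rely on a self.mock_annotations fixture built in the test's setUp, which is not shown here. Below is a minimal, hypothetical sketch of what such a fixture could look like, following the nuScenes sample_annotation schema; the tokens, translations and rotations are illustrative assumptions, not the original values:

mock_annotations = [
    {'token': '1', 'instance_token': '1', 'sample_token': '1',
     'translation': [1.0, 1.0, 0.0], 'rotation': [1.0, 0.0, 0.0, 0.0],
     'prev': '', 'next': '2'},
    {'token': '2', 'instance_token': '1', 'sample_token': '2',
     'translation': [2.0, 2.0, 0.0], 'rotation': [1.0, 0.0, 0.0, 0.0],
     'prev': '1', 'next': ''},
    # ...one record per (instance, sample) pair exercised by the tests
]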
Example #2
    def test_get_past_for_agent_in_frame(self):

        mock_samples = [{
            'token': '5',
            'timestamp': 0
        }, {
            'token': '4',
            'timestamp': -1e6
        }, {
            'token': '3',
            'timestamp': -2e6
        }, {
            'token': '2',
            'timestamp': -3e6
        }, {
            'token': '1',
            'timestamp': -4e6
        }]

        # Testing we can get the past trajectory in the agent's frame of reference
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)
        past = helper.get_past_for_agent('1', '5', 3, True)
        np.testing.assert_allclose(past,
                                   np.array([[1., -1.], [2., -2.], [3., -3.]]))
Example #3
    def test_get_past_for_last_returns_nothing(self):
        mock_samples = [{'token': '1', 'timestamp': 0}]

        # Testing we get nothing if we're at the last annotation
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)
        past = helper.get_past_for_agent('1', '1', 3, False)
        np.testing.assert_equal(past, np.array([]))
Example #4
    def test_get_no_data_when_seconds_0(self):
        mock_samples = [{'token': '1', 'timestamp': 0, 'anns': ['1']}]
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)

        np.testing.assert_equal(
            helper.get_future_for_agent('1', '1', 0, False), np.array([]))
        np.testing.assert_equal(helper.get_past_for_agent('1', '1', 0, False),
                                np.array([]))
        np.testing.assert_equal(helper.get_future_for_sample('1', 0, False),
                                np.array([]))
        np.testing.assert_equal(helper.get_past_for_sample('1', 0, False),
                                np.array([]))
Example #5
    def test_get_past_for_agent_no_data_to_get(self):
        mock_samples = [{
            'token': '5',
            'timestamp': 0
        }, {
            'token': '4',
            'timestamp': -3.5e6
        }]

        # Testing we get nothing if the first sample annotation is past our threshold
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)
        past = helper.get_past_for_agent('1', '5', 3, False)
        np.testing.assert_equal(past, np.array([]))
Example #6
    def test_get_past_for_agent_within_buffer(self):

        mock_samples = [{
            'token': '5',
            'timestamp': 0
        }, {
            'token': '4',
            'timestamp': -1e6
        }, {
            'token': '3',
            'timestamp': -3.05e6
        }, {
            'token': '2',
            'timestamp': -3.2e6
        }]

        # Testing we get data if it is slightly beyond the requested past seconds but within the buffer
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)
        past = helper.get_past_for_agent('1', '5', 3, False)
        np.testing.assert_equal(past, np.array([[3, 3], [2, 2]]))
Example #7
    def test_get_past_for_agent_less_amount(self):

        mock_samples = [{
            'token': '5',
            'timestamp': 0
        }, {
            'token': '4',
            'timestamp': -1e6
        }, {
            'token': '3',
            'timestamp': -2.6e6
        }, {
            'token': '2',
            'timestamp': -4e6
        }, {
            'token': '1',
            'timestamp': -5.5e6
        }]

        # Testing we do not include data after the past seconds
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)
        past = helper.get_past_for_agent('1', '5', 3, False)
        np.testing.assert_equal(past, np.array([[3, 3], [2, 2]]))
Example #8
    def test_get_past_for_agent_exact_amount(self):

        mock_samples = [{
            'token': '5',
            'timestamp': 0
        }, {
            'token': '4',
            'timestamp': -1e6
        }, {
            'token': '3',
            'timestamp': -2e6
        }, {
            'token': '2',
            'timestamp': -3e6
        }, {
            'token': '1',
            'timestamp': -4e6
        }]

        # Testing we can get the exact amount of past seconds available
        nusc = MockNuScenes(self.mock_annotations, mock_samples)
        helper = PredictHelper(nusc)
        past = helper.get_past_for_agent('1', '5', 3, False)
        np.testing.assert_equal(past, np.array([[3, 3], [2, 2], [1, 1]]))
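
The only difference between Example #2 and Example #8 is the in_agent_frame flag: the same past positions are returned either in global coordinates or rotated and translated into the agent's frame. A minimal sketch of the underlying conversion using the devkit helper; the pose values passed here are illustrative assumptions, not the mock agent's actual pose:

import numpy as np
from nuscenes.prediction.helper import convert_global_coords_to_local

# Global (x, y) positions and a hypothetical agent pose (translation, quaternion rotation)
coords = np.array([[3.0, 3.0], [2.0, 2.0], [1.0, 1.0]])
local = convert_global_coords_to_local(coords, (0.0, 0.0, 0.0), (1.0, 0.0, 0.0, 0.0))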
Example #9
def compute_metrics(predictions: List[Dict[str, Any]], helper: PredictHelper,
                    config: PredictionConfig
                    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Computes metrics for a set of predictions.
    :param predictions: List of prediction JSON objects.
    :param helper: Instance of PredictHelper that wraps the nuScenes val set.
    :param config: Prediction config.
    :return: Tuple of (aggregations, containers): a nested dictionary mapping each
        metric name to its aggregator results, and the raw per-prediction containers
        (metric values, image info and tokens).
    """
    n_preds = len(predictions)
    containers = {
        metric.name: np.zeros((n_preds, metric.shape))
        for metric in config.metrics
    }
    containers['imginfo'] = []
    containers['tokens'] = []

    BACK_SAMPLES = 5

    for i, prediction_str in enumerate(tqdm(predictions)):
        prediction = Prediction.deserialize(prediction_str)
        ground_truth = helper.get_future_for_agent(prediction.instance,
                                                   prediction.sample,
                                                   config.seconds,
                                                   in_agent_frame=False)
        agent_past = helper.get_past_for_agent(prediction.instance,
                                               prediction.sample,
                                               BACK_SAMPLES * 0.5,
                                               in_agent_frame=False,
                                               just_xy=False)

        if len(agent_past) < BACK_SAMPLES:
            print("Sample {} didn't have enough history ({} records)".format(
                i, len(agent_past)))
            continue

        cam_names = [
            sensor['channel'] for sensor in nusc.sensor
            if 'CAM' in sensor['channel']
        ]
        containers['tokens'].append(prediction_str)
        cameras = []
        containers['imginfo'].append(cameras)

        # Append current timestep
        d = {}
        cameras.append(d)
        token = prediction.instance + '_' + prediction.sample
        for cam_name in cam_names:
            d[cam_name] = get_im_and_box(
                nusc, token, cam_name=cam_name,
                imgAsName=True)  # impath, box, camera_intrinsics

        # Append earlier timesteps
        for t in range(BACK_SAMPLES):
            d = {}
            cameras.append(d)
            token = agent_past[t]['instance_token'] + '_' + agent_past[t][
                'sample_token']
            for cam_name in cam_names:
                d[cam_name] = get_im_and_box(
                    nusc, token, cam_name=cam_name,
                    imgAsName=True)  # impath, box, camera_intrinsics

        for metric in config.metrics:
            containers[metric.name][i] = metric(ground_truth, prediction)

    aggregations: Dict[str, Dict[str, List[float]]] = defaultdict(dict)
    for metric in config.metrics:
        for agg in metric.aggregators:
            aggregations[metric.name][agg.name] = agg(containers[metric.name])
    return aggregations, containers
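
compute_metrics above also assumes a module-level nusc (a NuScenes instance) and a get_im_and_box helper returning (impath, box, camera_intrinsics). A minimal sketch of how it might be invoked, assuming the devkit's load_prediction_config loader and a serialized predictions file; the paths and config name are assumptions:

import json

from nuscenes.nuscenes import NuScenes
from nuscenes.eval.prediction.config import load_prediction_config
from nuscenes.prediction import PredictHelper

nusc = NuScenes('v1.0-trainval', dataroot='./data/sets/nuscenes')
helper = PredictHelper(nusc)
config = load_prediction_config(helper, 'predict_2020_icra.json')

with open('predictions.json') as f:
    predictions = json.load(f)

aggregations, containers = compute_metrics(predictions, helper, config)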
Example #10
class NuScenesFormatTransformer:
    def __init__(self,
                 DATAROOT='./data/sets/nuscenes',
                 dataset_version='v1.0-mini'):
        self.DATAROOT = DATAROOT
        self.dataset_version = dataset_version
        self.nuscenes = NuScenes(dataset_version, dataroot=self.DATAROOT)
        self.helper = PredictHelper(self.nuscenes)

    def get_format_mha_jam(self,
                           samples_agents,
                           out_file="./transformer_format.txt"):
        instance_token_to_id_dict = {}
        sample_token_to_id_dict = {}

        scene_token_dict = {}
        sample_id = 0
        instance_id = 0

        for current_sample in tqdm(samples_agents):
            instance_token, sample_token = current_sample.split("_")
            scene_token = self.nuscenes.get('sample',
                                            sample_token)["scene_token"]

            if scene_token in scene_token_dict:
                continue

            # get the first sample in this sequence
            scene_token_dict[scene_token] = True
            first_sample_token = self.nuscenes.get(
                "scene", scene_token)["first_sample_token"]
            current_sample = self.nuscenes.get('sample', first_sample_token)

            while True:
                if current_sample['token'] not in sample_token_to_id_dict:
                    sample_token_to_id_dict[
                        current_sample['token']] = sample_id
                    sample_token_to_id_dict[sample_id] = current_sample[
                        'token']
                    sample_id += 1
                else:
                    print("should not happen?")

                instances_in_sample = self.helper.get_annotations_for_sample(
                    current_sample['token'])

                for sample_instance in instances_in_sample:
                    if sample_instance[
                            'instance_token'] not in instance_token_to_id_dict:
                        instance_token_to_id_dict[
                            sample_instance['instance_token']] = instance_id
                        instance_token_to_id_dict[
                            instance_id] = sample_instance['instance_token']
                        instance_id += 1

                if current_sample['next'] == "":
                    break

                current_sample = self.nuscenes.get('sample',
                                                   current_sample['next'])

        mode = "train" if out_file.find("_train") != -1 else "val"
        mini = "mini" if out_file.find("mini") != -1 else "main"

        with open(
                "dicts_sample_and_instances_id2token_" + mode + "_" + mini +
                ".json", 'w') as fw:
            json.dump([instance_token_to_id_dict, sample_token_to_id_dict], fw)
        #############
        # Converting to the transformer network format
        # frame_id, agent_id, pos_x, pos_y
        # todo:
        # loop on all the agents, if agent not taken:
        # 1- add it to taken agents (do not retake the agent)
        # 2- get the number of appearance of this agent
        # 3- skip this agent if the number is less than 10s (4 + 6)
        # 4- get the middle agent's token
        # 5- get the past and future agent's locations relative to its location
        samples_new_format = []
        taken_instances = {}
        ds_size = 0
        # max_past_traj_len = -1

        for current_sample in samples_agents:
            instance_token, sample_token = current_sample.split("_")
            instance_id = instance_token_to_id_dict[instance_token]

            if instance_id in taken_instances:
                continue

            taken_instances[instance_id] = True

            # trajectory_full_instances = self.get_trajectory_around_sample(instance_token, sample_token,
            #                                                               just_xy=False)

            # //////////////////////
            future_samples = self.helper.get_future_for_agent(
                instance_token, sample_token, 6, True, False)
            past_samples = self.helper.get_past_for_agent(
                instance_token, sample_token, 1000, True,
                False)[:MAX_TRAJ_LEN - 1][::-1]

            current_sample = self.helper.get_sample_annotation(
                instance_token, sample_token)
            assert len(past_samples) >= 1
            assert len(future_samples) == 12

            # assert len(past_samples) < 7
            # if len(past_samples) > max_past_traj_len:
            #     max_past_traj_len = len(past_samples)

            # past_samples = np.append(past_samples, [current_sample], axis=0)

            ds_size += 1

            # get_trajectory at this position
            center_pos = len(past_samples)
            future_samples_local = self.helper.get_future_for_agent(
                instance_token, sample_token, 6, True, True)
            past_samples_local = self.helper.get_past_for_agent(
                instance_token, sample_token, 1000, True,
                True)[:MAX_TRAJ_LEN - 1][::-1]
            # current_sample = self.helper.get_sample_annotation(instance_token, sample_token)
            assert len(future_samples_local) == 12

            # if len(past_samples) > 7:
            #     past_samples = past_samples[len(past_samples)-7:]
            #     past_samples_local = past_samples_local[past_samples_local.shape[0]-7:]

            trajectory = np.append(past_samples_local,
                                   np.append([[0, 0]],
                                             future_samples_local,
                                             axis=0),
                                   axis=0)

            past_samples = [
                sample_token_to_id_dict[p['sample_token']]
                for p in past_samples
            ]
            future_samples = [
                sample_token_to_id_dict[p['sample_token']]
                for p in future_samples
            ]
            trajectory_tokens = np.append(
                past_samples,
                np.append([sample_token_to_id_dict[sample_token]],
                          future_samples,
                          axis=0),
                axis=0)

            trajectory_ = np.zeros((trajectory.shape[0], 6))
            trajectory_[:, 0] = trajectory_tokens[:]
            trajectory_[:, 1:3] = trajectory
            trajectory = trajectory_
            len_future_samples = len(future_samples)
            del trajectory_, trajectory_tokens, past_samples, future_samples, past_samples_local, future_samples_local

            curr_sample = self.helper.get_past_for_agent(
                instance_token, sample_token, 1000, False,
                False)[:MAX_TRAJ_LEN][-1]

            for i in range(trajectory.shape[0]):
                # instance_id, sample_id, x, y, velocity, acc, yaw
                velocity = self.helper.get_velocity_for_agent(
                    instance_token, curr_sample["sample_token"])
                acceleration = self.helper.get_acceleration_for_agent(
                    instance_token, curr_sample["sample_token"])
                heading_change_rate = self.helper.get_heading_change_rate_for_agent(
                    instance_token, curr_sample["sample_token"])

                if math.isnan(velocity):
                    velocity = 0
                if math.isnan(acceleration):
                    acceleration = 0
                if math.isnan(heading_change_rate):
                    heading_change_rate = 0

                # need to check paper for relative velocity? same for acc and yaw
                trajectory[i][3:] = [
                    velocity, acceleration, heading_change_rate
                ]
                # if curr_sample['next'] == '':
                #     import pdb
                #     pdb.set_trace()

                # No need to get the next sample token when this is the last element
                # in the series (prevents an invalid 'next' lookup)
                if i < trajectory.shape[0] - 1:
                    next_sample_token = self.nuscenes.get(
                        'sample_annotation',
                        curr_sample['next'])['sample_token']
                    curr_sample = self.helper.get_sample_annotation(
                        instance_token, next_sample_token)

            s = str(instance_id) + ","
            # assert (MAX_TRAJ_LEN+len_future_samples) >= trajectory.shape[0]
            repeat = (MAX_TRAJ_LEN + len_future_samples) - trajectory.shape[0]
            leading_arr = np.array(
                repeat * [-1, -64, -64, -64, -64, -64]).reshape((repeat, 6))
            trajectory = np.append(leading_arr, trajectory, axis=0)

            # print("Built In!")
            # self.nuim.render_trajectory(sample_token, rotation_yaw=0, center_key_pose=True)
            # print("Bassel's!")
            # visualize_traffic(trajectory[(trajectory != [-1, -64, -64, -64, -64, -64]).all(axis=1), 1:3].copy())

            for i in range(trajectory.shape[0]):
                sample_id, x, y, velocity, acceleration, heading_change_rate = trajectory[
                    i]
                s += str(sample_id) + "," + str(x) + "," + str(y) + "," + str(velocity) + "," \
                     + str(acceleration) + "," + str(heading_change_rate)
                if i != trajectory.shape[0] - 1:
                    s += ","
                else:
                    s += "\n"

            samples_new_format.append(s)

        # print("max past trajectory len:",max_past_traj_len)

        # samples_new_format.sort(key=lambda x: int(x.split(",")[0]))

        with open(out_file, 'w') as fw:
            fw.writelines(samples_new_format)

        print(out_file + " size: " + str(ds_size))

    def get_format_mha_jam_context(self, states_filepath, out_file):
        with open(states_filepath) as fr:
            agents_states = fr.readlines()

        # format
        # agent_id, 20 x (frame_id, x, y, v, a, yaw_rate)
        agents_states = [[float(x.rstrip()) for x in s.split(',')]
                         for s in agents_states]

        mode = "train" if out_file.find("_train") != -1 else "val"
        mini = "mini" if out_file.find("mini") != -1 else "main"

        with open("dicts_sample_and_instances_id2token_" + mode + "_" + mini +
                  ".json") as fr:
            instance_dict_id_token, sample_dict_id_token = json.load(fr)

        # Get Context for each sample in states
        context = []
        agent_ind = 0

        for agent in tqdm(agents_states):
            instance_token = instance_dict_id_token[str(int(agent[0]))]
            mid_frame_id = int(agent[1 + 6 * (MAX_TRAJ_LEN - 1)])
            sample_token = sample_dict_id_token[str(mid_frame_id)]
            frame_annotations = self.helper.get_annotations_for_sample(
                sample_token)
            surroundings_agents_coords = []
            surroundings_agents_instance_token = []

            for ann in frame_annotations:
                if ann['category_name'].find("vehicle") == -1:
                    continue
                if ann['instance_token'] == instance_token:
                    agent_ann = ann
                else:
                    surroundings_agents_coords.append(ann["translation"][:2])
                    surroundings_agents_instance_token.append(
                        ann["instance_token"])

            if len(surroundings_agents_coords) != 0:
                surroundings_agents_coords = convert_global_coords_to_local(
                    surroundings_agents_coords, agent_ann["translation"],
                    agent_ann["rotation"])

            # for i in range(len(surroundings_agents_coords)):
            #     if surroundings_agents_coords[i][0] < -25 or surroundings_agents_coords[i][0] > 25 \
            #             or surroundings_agents_coords[i][1] < -10 or surroundings_agents_coords[i][1] > 40:
            #         surroundings_agents_coords[i] = None
            #         surroundings_agents_instance_token[i] = None

            total_area_side = 50
            cell_size = 1.5625
            map_side_size = int(total_area_side // cell_size)

            map = [[[-64, -64, -64, -64, -64] for i in range(MAX_TRAJ_LEN)]
                   for j in range(map_side_size * map_side_size)]

            for n in range(len(surroundings_agents_coords)):
                # if np.isnan(surroundings_agents_coords[n][0]): # ---> surroundings_agents_coords[n] is None
                #     continue
                # search for the agent location in the map
                # agent_found = False
                # for i in range(map_side_size):
                #     for j in range(map_side_size):
                #         # if agent found in the cell
                #         if surroundings_agents_coords[n][0] >= (j * cell_size) - 25\
                #                 and surroundings_agents_coords[n][0] < (j * cell_size) - 25 + cell_size \
                #                 and surroundings_agents_coords[n][1] < 40 - (i * cell_size) \
                #                 and surroundings_agents_coords[n][1] > 40 - (i * cell_size + cell_size):
                #             found_i, found_j = i, j
                #             break

                # get the agent location in the map!
                alpha_y = (surroundings_agents_coords[n][1] - (-10)) / (40 -
                                                                        (-10))
                i = (map_side_size - 1) - int(alpha_y * map_side_size + 0)

                alpha_x = (surroundings_agents_coords[n][0] - (-25)) / (25 -
                                                                        (-25))
                j = int(alpha_x * map_side_size + 0)

                # Confirmation the 2 methods yield the same results
                # if not(found_i == i and found_j == j):
                #     raise Exception("Calculations error")

                # prevent out-of-bound cases (which should never happen when out-of-bound agents are filtered out above)
                if not (i >= 0 and i < map_side_size and j >= 0
                        and j < map_side_size):
                    # raise Exception("Calculations error")
                    continue

                pos = i * map_side_size + j

                past_trajectory = self.get_current_past_trajectory(
                    surroundings_agents_instance_token[n],
                    sample_token,
                    num_seconds=1000)[:MAX_TRAJ_LEN]
                assert len(past_trajectory) <= MAX_TRAJ_LEN
                retrieved_trajectory_len = len(past_trajectory)

                if map[pos][-1][0] != -64:
                    skip_traj = False
                    # Save the trajectory with greater length
                    for ind, map_pos in enumerate(map[pos]):
                        if map_pos[0] != -64:  # -64 marks an empty (padded) entry
                            if MAX_TRAJ_LEN - ind > retrieved_trajectory_len:
                                skip_traj = True
                    if skip_traj:
                        agent_found = True
                        break
                    else:
                        # print("new longer agent trajectory in cell")
                        pass

                past_trajectory = convert_global_coords_to_local(
                    past_trajectory, agent_ann["translation"],
                    agent_ann["rotation"])

                if retrieved_trajectory_len != MAX_TRAJ_LEN:
                    past_trajectory = np.concatenate([
                        np.array([[-64, -64]
                                  for _ in range(MAX_TRAJ_LEN -
                                                 past_trajectory.shape[0])]),
                        past_trajectory
                    ],
                                                     axis=0)

                neighbour_agent_features = []

                skip_traj = False

                for k in range(0, MAX_TRAJ_LEN):
                    if retrieved_trajectory_len > k:
                        if k == 0:
                            sample_token_i = sample_dict_id_token[str(
                                mid_frame_id)]
                        else:
                            sample_token_i = self.helper.get_sample_annotation(
                                surroundings_agents_instance_token[n],
                                sample_token_i)["prev"]
                            sample_token_i = self.nuscenes.get(
                                'sample_annotation',
                                sample_token_i)['sample_token']
                        try:
                            velocity = self.helper.get_velocity_for_agent(
                                surroundings_agents_instance_token[n],
                                sample_token_i)
                        except:
                            skip_traj = True
                            # print("error")
                            break
                        acceleration = self.helper.get_acceleration_for_agent(
                            surroundings_agents_instance_token[n],
                            sample_token_i)
                        heading_change_rate = self.helper.get_heading_change_rate_for_agent(
                            surroundings_agents_instance_token[n],
                            sample_token_i)
                        if math.isnan(velocity):
                            velocity = 0
                        if math.isnan(acceleration):
                            acceleration = 0
                        if math.isnan(heading_change_rate):
                            heading_change_rate = 0

                        neighbour_agent_features.append(
                            [velocity, acceleration, heading_change_rate])
                    else:
                        neighbour_agent_features.append([-64, -64, -64])

                if skip_traj:
                    print("skipping agent because it has missing data")
                    agent_found = True
                    break

                past_trajectory = np.concatenate(
                    [past_trajectory, neighbour_agent_features], axis=1)
                map[pos] = past_trajectory.tolist()
                # agent_found = True
                # break
                #     if agent_found:
                #         break

            map = np.array(map).astype(np.float16)

            if VISUALIZE_DATA:
                visualize_traffic_neighbours(map,
                                             map_side_size * map_side_size)

            # context.append(map)
            if not os.path.exists(os.path.dirname(out_file)):
                os.makedirs(os.path.dirname(out_file))

            np.save(out_file.replace("_.txt", "__" + str(agent_ind) + ".txt"),
                    map)
            agent_ind += 1

            # with open(out_file, 'ab') as fw:
            #     pickle.dump(map, fw)
            #     continue
            # fw.write(map)

    def get_current_past_trajectory(self,
                                    instance_token,
                                    sample_token,
                                    num_seconds,
                                    just_xy=True,
                                    in_agent_frame=False):
        past_samples = self.helper.get_past_for_agent(
            instance_token, sample_token, num_seconds, in_agent_frame,
            just_xy)[::-1]  #[0:7][::-1]
        current_sample = self.helper.get_sample_annotation(
            instance_token, sample_token)

        if just_xy:
            current_sample = current_sample["translation"][:2]
            if past_samples.shape[0] == 0:
                trajectory = np.array([current_sample])
            else:
                trajectory = np.append(past_samples, [current_sample], axis=0)
        else:
            trajectory = np.append(past_samples, [current_sample], axis=0)
        return trajectory

    def get_format_mha_jam_maps(self, states_filepath, out_file):
        with open(states_filepath) as fr:
            agents_states = fr.readlines()

        # format
        # agent_id, 20 x (frame_id, x, y, v, a, yaw_rate)
        agents_states = [[float(x.rstrip()) for x in s.split(',')]
                         for s in agents_states]

        mode = "train" if out_file.find("_train") != -1 else "val"
        mini = "mini" if out_file.find("mini") != -1 else "main"

        with open("dicts_sample_and_instances_id2token_" + mode + "_" + mini +
                  ".json") as fr:
            instance_dict_id_token, sample_dict_id_token = json.load(fr)

        # Get map for each sample in states
        agent_ind = 0
        static_layer_rasterizer = StaticLayerRasterizer(self.helper)
        agent_rasterizer = AgentBoxesWithFadedHistory(self.helper,
                                                      seconds_of_history=1)
        mtp_input_representation = InputRepresentation(static_layer_rasterizer,
                                                       agent_rasterizer,
                                                       Rasterizer())

        if not os.path.exists(os.path.dirname(out_file)):
            os.makedirs(os.path.dirname(out_file))

        for agent in tqdm(agents_states):
            instance_token = instance_dict_id_token[str(int(agent[0]))]
            mid_frame_id = int(agent[1 + 6 * (MAX_TRAJ_LEN)])
            sample_token = sample_dict_id_token[str(mid_frame_id)]
            img = mtp_input_representation.make_input_representation(
                instance_token, sample_token)
            # img = cv2.resize(img, (1024, 1024))
            cv2.imwrite(
                out_file.replace("_.jpg", "__" + str(agent_ind) + ".jpg"), img)
            agent_ind += 1

    def run(self, out_dir):
        if self.dataset_version.find("mini") != -1:
            train_agents = get_prediction_challenge_split(
                "mini_train", dataroot=self.DATAROOT)
            val_agents = get_prediction_challenge_split("mini_val",
                                                        dataroot=self.DATAROOT)
        else:
            train_agents = get_prediction_challenge_split(
                "train", dataroot=self.DATAROOT)
            train_agents.extend(
                get_prediction_challenge_split("train_val",
                                               dataroot=self.DATAROOT))
            val_agents = get_prediction_challenge_split("val",
                                                        dataroot=self.DATAROOT)

        ## Statistics
        # mx =-1
        # for  in train_agents:
        #     instance_token, sample_token = current_sample.split("_")
        #     past_samples_local = self.helper.get_past_for_agent(instance_token, sample_token, 100, True, True)[::-1]
        #     if len(past_samples_local) > mx:
        #         mx = len(past_samples_local)
        # print("max length of the past sequences for trainval is:",mx)
        # for instance_token, sample_token in train_agents:
        #     past_samples_local = self.helper.get_past_for_agent(instance_token, sample_token, 100, True, True)[::-1]
        #     if len(past_samples_local) > mx:
        #         mx = len(past_samples_local)
        # print("max length of the past sequence for val is:",mx)
        # return

        self.get_format_mha_jam(
            train_agents,
            os.path.join(out_dir,
                         "states_train_" + self.dataset_version + ".txt"))
        self.get_format_mha_jam_context(
            os.path.join(out_dir,
                         "states_train_" + self.dataset_version + ".txt"),
            os.path.join(out_dir, "context_train_" + self.dataset_version,
                         "context_train_.txt"))
        self.get_format_mha_jam_maps(
            os.path.join(out_dir,
                         "states_train_" + self.dataset_version + ".txt"),
            os.path.join(out_dir, "maps_train_" + self.dataset_version,
                         "maps_train_.jpg"))
        # 25
        self.get_format_mha_jam(
            val_agents,
            os.path.join(out_dir,
                         "states_val_" + self.dataset_version + ".txt"))
        self.get_format_mha_jam_context(
            os.path.join(out_dir,
                         "states_val_" + self.dataset_version + ".txt"),
            os.path.join(out_dir, "context_val_" + self.dataset_version,
                         "context_val_.txt"))
        self.get_format_mha_jam_maps(
            os.path.join(out_dir,
                         "states_val_" + self.dataset_version + ".txt"),
            os.path.join(out_dir, "maps_val_" + self.dataset_version,
                         "maps_val_.jpg"))
Example #11
class NS(Dataset):
    def __init__(self,
                 dataroot: str,
                 split: str,
                 t_h: float = 2,
                 t_f: float = 6,
                 grid_dim: int = 25,
                 img_size: int = 200,
                 horizon: int = 40,
                 grid_extent: Tuple[int, int, int, int] = (-25, 25, -10, 40),
                 num_actions: int = 4,
                 image_extraction_mode: bool = False):
        """
        Initializes dataset class for nuScenes prediction

        :param dataroot: Path to tables and data
        :param split: Dataset split for prediction benchmark ('train'/'train_val'/'val')
        :param t_h: Track history in seconds
        :param t_f: Prediction horizon in seconds
        :param grid_dim: Size of grid, default: 25x25
        :param img_size: Size of raster map image in pixels, default: 200x200
        :param horizon: MDP horizon
        :param grid_extent: Map extents in meters, (-left, right, -behind, front)
        :param num_actions: Number of actions for each state (4: [D,R,U,L] or 8: [D, R, U, L, DR, UR, DL, UL])
        :param image_extraction_mode: Whether dataset class is being used for image extraction
        """

        # Nuscenes dataset and predict helper
        self.dataroot = dataroot
        self.ns = NuScenes('v1.0-trainval', dataroot=dataroot)
        self.helper = PredictHelper(self.ns)
        self.token_list = get_prediction_challenge_split(split,
                                                         dataroot=dataroot)

        # Useful parameters
        self.grid_dim = grid_dim
        self.grid_extent = grid_extent
        self.img_size = img_size
        self.t_f = t_f
        self.t_h = t_h
        self.horizon = horizon
        self.num_actions = num_actions

        # Map row, column and velocity states to actual values
        grid_size_m = self.grid_extent[1] - self.grid_extent[0]
        self.row_centers = np.linspace(
            self.grid_extent[3] - grid_size_m / (self.grid_dim * 2),
            self.grid_extent[2] + grid_size_m / (self.grid_dim * 2),
            self.grid_dim)

        self.col_centers = np.linspace(
            self.grid_extent[0] + grid_size_m / (self.grid_dim * 2),
            self.grid_extent[1] - grid_size_m / (self.grid_dim * 2),
            self.grid_dim)

        # Surrounding agent input representation: populate grid with velocity, acc, yaw-rate
        self.agent_ip = AgentMotionStatesOnGrid(self.helper,
                                                resolution=grid_size_m /
                                                img_size,
                                                meters_ahead=grid_extent[3],
                                                meters_behind=-grid_extent[2],
                                                meters_left=-grid_extent[0],
                                                meters_right=grid_extent[1])

        # Image extraction mode is used for extracting map images offline prior to training
        self.image_extraction_mode = image_extraction_mode
        if self.image_extraction_mode:

            # Raster map representation
            self.map_ip = StaticLayerRasterizer(self.helper,
                                                resolution=grid_size_m /
                                                img_size,
                                                meters_ahead=grid_extent[3],
                                                meters_behind=-grid_extent[2],
                                                meters_left=-grid_extent[0],
                                                meters_right=grid_extent[1])

            # Raster map with agent boxes. Only used for visualization
            static_layer_rasterizer = StaticLayerRasterizer(
                self.helper,
                resolution=grid_size_m / img_size,
                meters_ahead=grid_extent[3],
                meters_behind=-grid_extent[2],
                meters_left=-grid_extent[0],
                meters_right=grid_extent[1])

            agent_rasterizer = AgentBoxesWithFadedHistory(
                self.helper,
                seconds_of_history=1,
                resolution=grid_size_m / img_size,
                meters_ahead=grid_extent[3],
                meters_behind=-grid_extent[2],
                meters_left=-grid_extent[0],
                meters_right=grid_extent[1])

            self.map_ip_agents = InputRepresentation(static_layer_rasterizer,
                                                     agent_rasterizer,
                                                     Rasterizer())

    def __len__(self):
        return len(self.token_list)

    def __getitem__(self, idx):
        """
        Returns inputs, ground truth values and other utilities for data point at given index

        :return hist: snippet of track history, default 2s at 0.5 Hz sampling frequency
        :return fut: ground truth future trajectory, default 6s at 0.5 Hz sampling frequency
        :return img: Imagenet normalized bird's eye view map around the target vehicle
        :return svf_e: Goal and path state visitation frequencies for expert demonstration, i.e. the path from the train set
        :return motion_feats: motion and position features used for reward model
        :return waypts_e: (x,y) BEV co-ordinates corresponding to grid cells of svf_e
        :return agents: tensor of surrounding agent states populated in grid around target agent
        :return grid_idcs: grid co-ordinates of svf_e
        :return bc_targets: ground truth actions for training behavior cloning model
        :return img_agents: image with agent boxes for visualization / debugging
        :return instance_token: nuScenes instance token for prediction instance
        :return sample_token: nuScenes sample token for prediction instance
        :return idx: instance id (mainly for debugging)
        """

        # Nuscenes instance and sample token for prediction data point
        instance_token, sample_token = self.token_list[idx].split("_")

        # If dataset is being used for image extraction
        grid_size_m = self.grid_extent[1] - self.grid_extent[0]
        if self.image_extraction_mode:

            # Make directory to store raster map images
            img_dir = os.path.join(
                self.dataroot, 'prediction_raster_maps', 'images' +
                str(self.img_size) + "_" + str(int(grid_size_m)) + 'm')
            if not os.path.isdir(img_dir):
                os.mkdir(img_dir)

            # Generate and save raster map image with just static elements
            img = self.map_ip.make_representation(instance_token, sample_token)
            img_save = Image.fromarray(img)
            img_save.save(
                os.path.join(img_dir,
                             instance_token + "_" + sample_token + '.png'))

            # Generate and save raster map image with static elements and agent boxes (for visualization only)
            img_agents = self.map_ip_agents.make_input_representation(
                instance_token, sample_token)
            img_agents_save = Image.fromarray(img_agents)
            img_agents_save.save(
                os.path.join(
                    img_dir,
                    instance_token + "_" + sample_token + 'agents.png'))

            # Return dummy values
            return 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

        # If dataset is being used for training/validation/testing
        else:

            # Get track history for agent:
            hist = self.get_hist(instance_token, sample_token)
            hist = torch.from_numpy(hist)

            # Get ground truth future for agent:
            fut = self.helper.get_future_for_agent(instance_token,
                                                   sample_token,
                                                   seconds=self.t_f,
                                                   in_agent_frame=True)
            fut = torch.from_numpy(fut)

            # Get indefinite future for computing expert State visitation frequencies (SVF):
            fut_indefinite = self.helper.get_future_for_agent(
                instance_token, sample_token, seconds=300, in_agent_frame=True)

            # Up sample indefinite future by a factor of 10
            fut_interpolated = np.zeros((fut_indefinite.shape[0] * 10 + 1, 2))
            param_query = np.linspace(0, fut_indefinite.shape[0],
                                      fut_indefinite.shape[0] * 10 + 1)
            param_given = np.linspace(0, fut_indefinite.shape[0],
                                      fut_indefinite.shape[0] + 1)
            val_given_x = np.concatenate(([0], fut_indefinite[:, 0]))
            val_given_y = np.concatenate(([0], fut_indefinite[:, 1]))
            fut_interpolated[:, 0] = np.interp(param_query, param_given,
                                               val_given_x)
            fut_interpolated[:, 1] = np.interp(param_query, param_given,
                                               val_given_y)

            # Read pre-extracted raster map image
            img_dir = os.path.join(
                self.dataroot, 'prediction_raster_maps', 'images' +
                str(self.img_size) + "_" + str(int(grid_size_m)) + 'm')
            img = cv2.imread(
                os.path.join(img_dir,
                             instance_token + "_" + sample_token + '.png'))

            # Pre-process image
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = torch.from_numpy(img)
            img = img.permute((2, 0, 1)).float() / 255

            # Normalize using Imagenet stats
            img = normalize_imagenet(img)

            # Read pre-extracted raster map with agent boxes (for visualization + debugging)
            img_agents = cv2.imread(
                os.path.join(
                    img_dir,
                    instance_token + "_" + sample_token + 'agents.png'))

            # Pre-process image
            img_agents = cv2.cvtColor(img_agents, cv2.COLOR_BGR2RGB)
            img_agents = torch.from_numpy(img_agents)
            img_agents = img_agents.permute((2, 0, 1)).float() / 255

            # Get surrounding agent states
            agents = torch.from_numpy(
                self.agent_ip.make_representation(instance_token,
                                                  sample_token))
            agents = agents.permute((2, 0, 1)).float()

            # Sum pool states to down-sample to grid dimensions
            agents = f.avg_pool2d(agents[None, :, :, :],
                                  self.img_size // self.grid_dim)
            agents = agents.squeeze(dim=0) * (
                (self.img_size // self.grid_dim)**2)

            # Get expert SVF:
            svf_e, waypts_e, grid_idcs = self.get_expert_waypoints(
                fut_interpolated)
            svf_e = torch.from_numpy(svf_e)
            waypts_e = torch.from_numpy(waypts_e)
            grid_idcs = torch.from_numpy(grid_idcs)

            # Get motion and position feats:
            motion_feats = self.get_motion_feats(instance_token, sample_token)
            motion_feats = torch.from_numpy(motion_feats)

            # Targets for behavior cloning model:
            bc_targets = self.get_bc_targets(fut_interpolated)
            bc_targets = torch.from_numpy(bc_targets)

            return hist, fut, img, svf_e, motion_feats, waypts_e, agents, grid_idcs, bc_targets, img_agents, \
                instance_token, sample_token, idx

    def get_hist(self, instance_token: str, sample_token: str):
        """
        Function to get track history of agent
        :param instance_token: nuScenes instance token for datapoint
        :param sample_token: nuScenes sample token for datapoint
        """
        # x, y co-ordinates in agent's frame of reference
        xy = self.helper.get_past_for_agent(instance_token,
                                            sample_token,
                                            seconds=self.t_h,
                                            in_agent_frame=True)

        # Get all history records for obtaining velocity, acceleration and turn rate values
        hist_records = self.helper.get_past_for_agent(instance_token,
                                                      sample_token,
                                                      seconds=self.t_h,
                                                      in_agent_frame=True,
                                                      just_xy=False)
        if xy.shape[0] > self.t_h * 2:
            xy = xy[0:int(self.t_h) * 2]
        if len(hist_records) > self.t_h * 2:
            hist_records = hist_records[0:int(self.t_h) * 2]

        # Initialize hist tensor and set x and y co-ordinates returned by prediction helper
        hist = np.zeros((xy.shape[0], 5))
        hist[:, 0:2] = xy

        # Instance and sample tokens from history records
        i_tokens = [
            hist_records[i]['instance_token'] for i in range(len(hist_records))
        ]
        i_tokens.insert(0, instance_token)
        s_tokens = [
            hist_records[i]['sample_token'] for i in range(len(hist_records))
        ]
        s_tokens.insert(0, sample_token)

        # Set velocity, acc and turn rate values for hist
        for k in range(hist.shape[0]):
            i_t = i_tokens[k]
            s_t = s_tokens[k]
            v = self.helper.get_velocity_for_agent(i_t, s_t)
            a = self.helper.get_acceleration_for_agent(i_t, s_t)
            theta = self.helper.get_heading_change_rate_for_agent(i_t, s_t)

            # If function returns nan values due to short tracks, set corresponding value to 0
            if np.isnan(v):
                v = 0
            if np.isnan(a):
                a = 0
            if np.isnan(theta):
                theta = 0
            hist[k, 2] = v
            hist[k, 3] = a
            hist[k, 4] = theta

        # Zero pad for track histories shorter than t_h
        hist_zeropadded = np.zeros((int(self.t_h) * 2, 5))

        # Flip to have correct order of timestamps
        hist = np.flip(hist, 0)
        hist_zeropadded[-hist.shape[0]:] = hist

        return hist_zeropadded

    def get_expert_waypoints(self, fut: np.ndarray):
        """
        Function to get the expert's state visitation frequencies based on their trajectory
        :param fut: numpy array with the future trajectory for all available future timestamps, up-sampled by a factor of 10
        """

        # Expert state visitation frequencies for training reward model, waypoints in meters and grid indices
        svf_e = np.zeros((2, self.grid_dim, self.grid_dim))
        waypts_e = np.zeros((self.horizon, 2))
        grid_idcs = np.zeros((self.horizon, 2))

        count = 0
        row_prev = np.nan
        column_prev = np.nan
        for k in range(fut.shape[0]):

            # Convert trajectory (x,y) co-ordinates to grid locations:
            column = np.argmin(np.absolute(fut[k, 0] - self.col_centers))
            row = np.argmin(np.absolute(fut[k, 1] - self.row_centers))

            # Demonstration ends when expert leaves the image crop corresponding to the grid:
            if self.grid_extent[0] <= fut[k, 0] <= self.grid_extent[1] and \
                    self.grid_extent[2] <= fut[k, 1] <= self.grid_extent[3]:

                # Check if cell location has changed
                if row != row_prev or column != column_prev:

                    # Add cell location to path states of expert
                    svf_e[0, row.astype(int), column.astype(int)] = 1

                    if count < self.horizon:

                        # Get BEV coordinates corresponding to cell locations
                        waypts_e[count, 0] = self.row_centers[row]
                        waypts_e[count, 1] = self.col_centers[column]
                        grid_idcs[count, 0] = row
                        grid_idcs[count, 1] = column
                        count += 1
            else:
                break
            column_prev = column
            row_prev = row

        # Last cell location where demonstration terminates is the goal state:
        svf_e[1, row_prev.astype(int), column_prev.astype(int)] = 1

        return svf_e, waypts_e, grid_idcs

    def get_motion_feats(self, instance_token: str, sample_token: str):
        """
        Function to get motion and position features over grid for reward model
        :param instance_token: NuScenes instance token for datapoint
        :param sample_token: NuScenes sample token for datapoint
        """
        feats = np.zeros((3, self.grid_dim, self.grid_dim))

        # X and Y co-ordinates over grid
        grid_size_m = self.grid_extent[1] - self.grid_extent[0]
        y = (np.linspace(
            self.grid_extent[3] - grid_size_m / (self.grid_dim * 2),
            self.grid_extent[2] + grid_size_m / (self.grid_dim * 2),
            self.grid_dim)).reshape(-1, 1).repeat(self.grid_dim, axis=1)
        x = (np.linspace(
            self.grid_extent[0] + grid_size_m / (self.grid_dim * 2),
            self.grid_extent[1] - grid_size_m / (self.grid_dim * 2),
            self.grid_dim)).reshape(-1, 1).repeat(self.grid_dim,
                                                  axis=1).transpose()

        # Velocity of agent
        v = self.helper.get_velocity_for_agent(instance_token, sample_token)
        if np.isnan(v):
            v = 0

        # Normalize X and Y co-ordinates over grid
        feats[0] = v
        feats[1] = x / grid_size_m
        feats[2] = y / grid_size_m

        return feats

    def get_bc_targets(self, fut: np.ndarray):
        """
        Function to get targets for behavior cloning model
        :param fut: numpy array with the future trajectory for all available future timestamps, up-sampled by a factor of 10
        """
        bc_targets = np.zeros(
            (self.num_actions + 1, self.grid_dim, self.grid_dim))
        column_prev = np.argmin(np.absolute(fut[0, 0] - self.col_centers))
        row_prev = np.argmin(np.absolute(fut[0, 1] - self.row_centers))

        for k in range(fut.shape[0]):

            # Convert trajectory (x,y) co-ordinates to grid locations:
            column = np.argmin(np.absolute(fut[k, 0] - self.col_centers))
            row = np.argmin(np.absolute(fut[k, 1] - self.row_centers))

            # Demonstration ends when expert leaves the image crop corresponding to the grid:
            if self.grid_extent[0] <= fut[k, 0] <= self.grid_extent[1] and self.grid_extent[2] <= fut[k, 1] <= \
                    self.grid_extent[3]:

                # Check if cell location has changed
                if row != row_prev or column != column_prev:
                    bc_targets[:, int(row_prev), int(column_prev)] = 0
                    d_x = column - column_prev
                    d_y = row - row_prev
                    theta = np.arctan2(d_y, d_x)

                    # Assign ground truth actions for expert demonstration
                    if self.num_actions == 4:  # [D,R,U,L,end]
                        if np.pi / 4 <= theta < 3 * np.pi / 4:
                            bc_targets[0, int(row_prev), int(column_prev)] = 1
                        elif -np.pi / 4 <= theta < np.pi / 4:
                            bc_targets[1, int(row_prev), int(column_prev)] = 1
                        elif -3 * np.pi / 4 <= theta < -np.pi / 4:
                            bc_targets[2, int(row_prev), int(column_prev)] = 1
                        else:
                            bc_targets[3, int(row_prev), int(column_prev)] = 1

                    else:  # [D, R, U, L, DR, UR, DL, UL, end]
                        if 3 * np.pi / 8 <= theta < 5 * np.pi / 8:
                            bc_targets[0, int(row_prev), int(column_prev)] = 1
                        elif -np.pi / 8 <= theta < np.pi / 8:
                            bc_targets[1, int(row_prev), int(column_prev)] = 1
                        elif -5 * np.pi / 8 <= theta < -3 * np.pi / 8:
                            bc_targets[2, int(row_prev), int(column_prev)] = 1
                        elif np.pi / 8 <= theta < 3 * np.pi / 8:
                            bc_targets[4, int(row_prev), int(column_prev)] = 1
                        elif -3 * np.pi / 8 <= theta < -np.pi / 8:
                            bc_targets[5, int(row_prev), int(column_prev)] = 1
                        elif 5 * np.pi / 8 <= theta < 7 * np.pi / 8:
                            bc_targets[6, int(row_prev), int(column_prev)] = 1
                        elif -7 * np.pi / 8 <= theta < -5 * np.pi / 8:
                            bc_targets[7, int(row_prev), int(column_prev)] = 1
                        else:
                            bc_targets[3, int(row_prev), int(column_prev)] = 1
            else:
                break
            column_prev = column
            row_prev = row

        # Final action is the end action to transition to the goal state:
        bc_targets[self.num_actions, int(row_prev), int(column_prev)] = 1

        return bc_targets
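
A minimal sketch of how the NS dataset might be consumed, assuming the standard PyTorch DataLoader API and that the raster map images were already extracted offline in image_extraction_mode (the dataroot and batch size are assumptions):

from torch.utils.data import DataLoader

dataset = NS(dataroot='./data/sets/nuscenes', split='train')
loader = DataLoader(dataset, batch_size=16, shuffle=True)
hist, fut, img, *rest = next(iter(loader))  # remaining items follow __getitem__'s return order
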
class NuScenesFormatTransformer:
    def __init__(self,
                 DATAROOT='./data/sets/nuscenes',
                 dataset_version='v1.0-mini'):
        self.DATAROOT = DATAROOT
        self.dataset_version = dataset_version
        self.nuscenes = NuScenes(dataset_version, dataroot=self.DATAROOT)
        self.helper = PredictHelper(self.nuscenes)
        # ['vehicle.car', 'vehicle.truck', 'vehicle.bus.rigid', 'vehicle.bus.bendy', 'vehicle.construction']
        self.category_token_to_id = {
            "fd69059b62a3469fbaef25340c0eab7f": 1,  # 'vehicle.car'
            "6021b5187b924d64be64a702e5570edf": 1,  # 'vehicle.truck'
            "fedb11688db84088883945752e480c2c": 2,  # 'vehicle.bus.rigid'
            "003edbfb9ca849ee8a7496e9af3025d4": 2,  # 'vehicle.bus.bendy'
            "5b3cd6f2bca64b83aa3d0008df87d0e4": 3,  # 'vehicle.construction'
            "7b2ff083a64e4d53809ae5d9be563504": 1
        }  # vehicle.emergency.police

    def get_new_format(self,
                       samples_agents,
                       format_for_model,
                       out_file=("./transformer_format.txt"),
                       num_seconds=None):
        # for current_sample in samples_agents:
        #     instance_token, sample_token = current_sample.split("_")
        #     traj = self.helper.get_future_for_agent(instance_token, sample_token, 6, True)
        #     past_traj = self.helper.get_past_for_agent(instance_token, sample_token, 6, True)
        #
        #     if len(past_traj) + len(traj) + 1 < 20:
        #         print(len(past_traj) + len(traj) + 1)
        #
        # exit()
        ####################
        # Sample Token (frame) to a sequential id
        # for each sample (agent_frame), get the scene it belongs to, and then get the first sample (frame)
        # loop on all samples from the first sample till the end
        # set to dictionary the sequential id for each sample
        splitting_format = '\t'

        if format_for_model.value == FORMAT_FOR_MODEL.TRAFFIC_PREDICT.value:
            splitting_format = " "

        instance_token_to_id_dict = {}
        sample_token_to_id_dict = {}
        scene_token_dict = {}
        sample_id = 0
        instance_id = 0

        for current_sample in samples_agents:
            instance_token, sample_token = current_sample.split("_")
            scene_token = self.nuscenes.get('sample',
                                            sample_token)["scene_token"]

            if scene_token in scene_token_dict:
                continue

            # get the first sample in this sequence
            scene_token_dict[scene_token] = True
            first_sample_token = self.nuscenes.get(
                "scene", scene_token)["first_sample_token"]
            current_sample = self.nuscenes.get('sample', first_sample_token)

            while True:
                if current_sample['token'] not in sample_token_to_id_dict:
                    sample_token_to_id_dict[
                        current_sample['token']] = sample_id
                    sample_id += 1
                else:
                    # Each scene is only visited once, so a sample should never already have an id.
                    print("Warning: sample token already has an id; this should not happen")

                instances_in_sample = self.helper.get_annotations_for_sample(
                    current_sample['token'])

                for sample_instance in instances_in_sample:
                    if sample_instance[
                            'instance_token'] not in instance_token_to_id_dict:
                        instance_token_to_id_dict[
                            sample_instance['instance_token']] = instance_id
                        instance_id += 1

                if current_sample['next'] == "":
                    break

                current_sample = self.nuscenes.get('sample',
                                                   current_sample['next'])

        #############
        # Converting to the transformer network format
        # frame_id, agent_id, pos_x, pos_y
        # todo:
        # loop over all the agents; if an agent has not been taken yet:
        # 1- add it to the taken agents (do not process the agent again)
        # 2- get the number of appearances of this agent
        # 3- skip this agent if it appears for less than 10 seconds (4 s past + 6 s future)
        # 4- get the agent's middle sample token
        # 5- get the agent's past and future locations relative to that sample
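        # nuScenes keyframes are sampled at 2 Hz, so 10 s corresponds to the 20-sample
        # minimum trajectory length checked below.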
        samples_new_format = []
        taken_instances = {}
        ds_size = 0

        for current_sample in samples_agents:
            instance_token, sample_token = current_sample.split("_")
            instance_id, sample_id = instance_token_to_id_dict[
                instance_token], sample_token_to_id_dict[sample_token]

            if instance_id in taken_instances:
                continue

            taken_instances[instance_id] = True

            trajectory = self.get_trajectory_around_sample(
                instance_token, sample_token)
            trajectory_full_instances = self.get_trajectory_around_sample(
                instance_token, sample_token, just_xy=False)
            # traj_samples_token = [instance['sample_token'] for instance in trajectory_full_instances]

            if len(trajectory) < 20:
                print("length is less than 20 samples, trajectory length is: ",
                      len(trajectory))
                continue

            ds_size += 1

            if num_seconds is not None:
                start = len(trajectory) // 2 - 9
                end = len(trajectory) // 2 + 11
                starting_frame = (start + end) // 2

                middle_sample_token = trajectory_full_instances[
                    starting_frame]["sample_token"]
                trajectory = self.get_trajectory_around_sample(
                    instance_token,
                    middle_sample_token,
                    just_xy=True,
                    num_seconds=num_seconds,
                    in_agent_frame=True)
                trajectory_full_instances = self.get_trajectory_around_sample(
                    instance_token,
                    middle_sample_token,
                    just_xy=False,
                    num_seconds=num_seconds,
                    in_agent_frame=True)
                # traj_samples_token = [instance['sample_token'] for instance in trajectory_full_instances]

            # get_trajectory at this position
            for i in range(trajectory.shape[0]):
                traj_sample, sample_token = trajectory[
                    i], trajectory_full_instances[i]["sample_token"]
                sample_id = sample_token_to_id_dict[sample_token]
                if format_for_model.value == FORMAT_FOR_MODEL.TRANSFORMER_NET.value:
                    yaw = quaternion_yaw(
                        Quaternion(trajectory_full_instances[i]["rotation"]))

                    x, y, z = trajectory_full_instances[i]["translation"]
                    w, l, h = trajectory_full_instances[i]["size"]

                    samples_new_format.append(
                        str(sample_id) + splitting_format + str(instance_id) +
                        splitting_format + str(x) + splitting_format + str(y) +
                        splitting_format + str(yaw) + "\n")
                elif format_for_model.value == FORMAT_FOR_MODEL.TRAFFIC_PREDICT.value:
                    category_token = self.nuscenes.get(
                        "instance", instance_token)["category_token"]
                    object_type = self.category_token_to_id[category_token]
                    # frame_id, object_id, object_type,
                    # position_x, position_y, position_z,
                    # object_length, object_width, object_height,
                    # heading
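                    # i.e. one space-separated line per (frame, agent) of the form:
                    #   frame_id object_id object_type x y z length width height heading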
                    x, y, z = trajectory_full_instances[i]["translation"]
                    w, l, h = trajectory_full_instances[i]["size"]
                    # yaw = angle_of_rotation(quaternion_yaw(Quaternion(trajectory_full_instances[i]["rotation"])))
                    yaw = quaternion_yaw(
                        Quaternion(trajectory_full_instances[i]["rotation"]))

                    samples_new_format.append(
                        str(sample_id) + splitting_format + str(instance_id) +
                        splitting_format + str(object_type) +
                        splitting_format + str(x) + splitting_format + str(y) +
                        splitting_format + str(z) +
                        splitting_format + str(l) + splitting_format + str(w) +
                        splitting_format + str(h) +
                        splitting_format + str(yaw) + "\n")

        # Sorting is not strictly needed here (the transformer-net data loader sorts again),
        # but it is kept for consistency with that loader.

        samples_new_format.sort(
            key=lambda x: int(x.split(splitting_format)[0]))

        with open(out_file, 'w') as fw:
            fw.writelines(samples_new_format)

        print(out_file + "size " + str(ds_size))

    def run(self, format_for_model):
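        """Build the train/val trajectory text files for the chosen model format from
        the nuScenes prediction-challenge splits (mini splits for a mini dataset)."""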
        if self.dataset_version.find("mini") != -1:
            train_agents = get_prediction_challenge_split(
                "mini_train", dataroot=self.DATAROOT)
            val_agents = get_prediction_challenge_split("mini_val",
                                                        dataroot=self.DATAROOT)
        else:
            train_agents = get_prediction_challenge_split(
                "train", dataroot=self.DATAROOT)
            train_agents.extend(
                get_prediction_challenge_split("train_val",
                                               dataroot=self.DATAROOT))
            val_agents = get_prediction_challenge_split("val",
                                                        dataroot=self.DATAROOT)

        self.get_new_format(
            train_agents, format_for_model,
            "/home/bassel/PycharmProjects/Trajectory-Transformer/datasets/nuscenes/bkup/transformer_train_"
            + self.dataset_version + ".txt")
        self.get_new_format(
            val_agents, format_for_model,
            "/home/bassel/PycharmProjects/Trajectory-Transformer/datasets/nuscenes/bkup/transformer_val_"
            + self.dataset_version + ".txt")
        # shutil.copy("/home/bassel/PycharmProjects/Trajectory-Transformer/datasets/nuscenes/val/transformer_val_"+self.dataset_version+".txt",
        #             "/home/bassel/PycharmProjects/Trajectory-Transformer/datasets/nuscenes/test/transformer_val_"+self.dataset_version+".txt")

    def get_trajectory_around_sample(self,
                                     instance_token,
                                     sample_token,
                                     just_xy=True,
                                     num_seconds=1000,
                                     in_agent_frame=False):
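        """Stitch the agent's past, current and future annotations into one trajectory,
        ordered from oldest to newest.

        Returns an array of xy positions when just_xy is True, otherwise an array of
        full sample-annotation records.
        """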
        future_samples = self.helper.get_future_for_agent(
            instance_token, sample_token, num_seconds, in_agent_frame, just_xy)
        past_samples = self.helper.get_past_for_agent(instance_token,
                                                      sample_token,
                                                      num_seconds,
                                                      in_agent_frame,
                                                      just_xy)[::-1]
        current_sample = self.helper.get_sample_annotation(
            instance_token, sample_token)

        if num_seconds == 5:
            # Keep at most 9 past + 1 current + 10 future samples (20 rows total at 2 Hz keyframes).
            if len(past_samples) > 9:
                past_samples = past_samples[0:9]
            if len(future_samples) > 10:
                future_samples = future_samples[0:10]

        if just_xy:
            current_sample = current_sample["translation"][:2]

        trajectory = np.append(past_samples,
                               np.append([current_sample],
                                         future_samples,
                                         axis=0),
                               axis=0)
        return trajectory
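

# Minimal usage sketch (assumes the FORMAT_FOR_MODEL enum defined elsewhere in this
# file and a local nuScenes mini split under DATAROOT; note that run() writes to the
# hard-coded output paths above, so adjust them for your environment):
if __name__ == "__main__":
    transformer = NuScenesFormatTransformer(DATAROOT='./data/sets/nuscenes',
                                            dataset_version='v1.0-mini')
    transformer.run(FORMAT_FOR_MODEL.TRANSFORMER_NET)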