コード例 #1
0
    def test_included_tracks(self):
        """Only (track file, ego id) pairs listed in IncludeTracks become scenarios.

        Two dummy track files are configured and ``IncludeTracks`` restricts
        generation to ego agent 2 of the first file and ego agents 1 and 3 of
        the second file. The generator is checked both when fewer scenarios
        are requested than are included (only the requested number is
        produced, all distinct and allowed) and when more are requested (every
        included pair is produced exactly once).
        """
        params = ParameterServer()
        data_dir = os.path.dirname(__file__)

        map_filename = os.path.join(data_dir,
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename_1 = os.path.join(
            data_dir, "data/interaction_dataset_dummy_track.csv")
        track_filename_2 = os.path.join(
            data_dir, "data/interaction_dataset_dummy_track_2.csv")

        include_tracks = {
            track_filename_1: [2],
            track_filename_2: [1, 3],
        }
        # Every (track file, ego id) combination IncludeTracks allows. Derived
        # from the very dict written to the parameters so the expectation
        # cannot drift from the configuration.
        allowed_pairs = {(track_file, ego_id)
                         for track_file, ego_ids in include_tracks.items()
                         for ego_id in ego_ids}

        generation_params = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_params["MapFilename"] = map_filename
        generation_params["TrackFilenameList"] = [
            track_filename_1, track_filename_2
        ]
        generation_params["IncludeTracks"] = include_tracks

        def collect_combinations(scenario_generation):
            """Return the (track file, ego id) pair of every generated scenario.

            Each pair is asserted to be one allowed by IncludeTracks.
            """
            combinations = []
            for scenario, _ in scenario_generation:
                pair = (scenario.json_params["track_file"],
                        scenario.eval_agent_ids[0])
                self.assertIn(pair, allowed_pairs)
                combinations.append(pair)
            return combinations

        # CASE 1: num_scenarios < number of scenarios in IncludeTracks
        scenario_generation_1 = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=2)
        # Only two scenarios should be generated (num_scenarios=2)
        self.assertEqual(scenario_generation_1.get_num_scenarios(), 2)
        combinations_1 = collect_combinations(scenario_generation_1)
        # No included (track file, ego id) pair may be generated twice.
        self.assertEqual(len(combinations_1), len(set(combinations_1)))

        # CASE 2: num_scenarios > number of scenarios in IncludeTracks
        scenario_generation_2 = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=4)
        # Three scenarios should be generated (all specified in IncludeTracks)
        self.assertEqual(scenario_generation_2.get_num_scenarios(), 3)
        # ...and together they must cover every included pair exactly once.
        self.assertCountEqual(collect_combinations(scenario_generation_2),
                              allowed_pairs)
コード例 #2
0
    def test_dataset_scenario_generation_full(self):
        """Requesting two scenarios from one dummy track file yields two."""
        data_dir = os.path.dirname(__file__)
        map_path = os.path.join(data_dir,
                                "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_path = os.path.join(data_dir,
                                  "data/interaction_dataset_dummy_track.csv")

        params = ParameterServer()
        generation_params = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_params["MapFilename"] = map_path
        generation_params["TrackFilenameList"] = [track_path]

        generator = InteractionDatasetScenarioGenerationFull(params=params,
                                                             num_scenarios=2)
        self.assertEqual(generator.get_num_scenarios(), 2)
コード例 #3
0
    def test_dataset_scenario_generation_full(self):
        """Two track ids that both become valid at time 0 give two scenarios.

        In each scenario, agents 1 and 2 report a first valid timestamp of 0.
        No agent appears in ``agents_valid`` before the world is stepped
        (agents start with Behavior::NotValidYet); after one ``Step(0.01)``
        both do.
        """
        params = ParameterServer()
        data_dir = os.path.dirname(__file__)
        generation_params = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_params["MapFilename"] = os.path.join(
            data_dir, "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        generation_params["TrackFilenameList"] = [
            os.path.join(data_dir,
                         "data/interaction_dataset_DEU_Merging_dummy_track.csv")
        ]
        generation_params["StartingOffsetMs"] = 0

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=2)
        self.assertEqual(scenario_generation.get_num_scenarios(), 2)

        expected_ids = [1, 2]
        for scenario_idx in range(2):
            scenario = scenario_generation.get_scenario(scenario_idx)
            # Deliberately call GetWorldState() anew for each access rather
            # than caching one world: the check after Step() expects the step
            # to be visible through a later GetWorldState() call.
            agents = [scenario.GetWorldState().agents[agent_id]
                      for agent_id in expected_ids]
            for agent in agents:
                self.assertEqual(agent.first_valid_timestamp, 0.0)
            # agents are initialized with Behavior::NotValidYet
            self.assertEqual(
                list(scenario.GetWorldState().agents_valid.keys()), [])
            scenario.GetWorldState().Step(0.01)
            self.assertEqual(
                list(scenario.GetWorldState().agents_valid.keys()),
                expected_ids)
コード例 #4
0
    def test_dataset_scenario_generation_full_incomplete(self):
        """Agent 1 is not part of the partial CHN merging map.

        Although three scenarios are requested, only two are generated.
        """
        here = os.path.dirname(__file__)
        params = ParameterServer()
        generation_params = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_params["MapFilename"] = os.path.join(
            here, "data/DR_CHN_Merging_ZS_partial_v02.xodr")
        generation_params["TrackFilenameList"] = [
            os.path.join(here,
                         "data/interaction_dataset_dummy_track_incomplete.csv")
        ]

        generator = InteractionDatasetScenarioGenerationFull(params=params,
                                                             num_scenarios=3)
        self.assertEqual(generator.get_num_scenarios(), 2)
コード例 #5
0
    def test_dataset_scenario_generation_full_incomplete(self):
        """Only agents that enter the partial CHN merging map yield scenarios.

        The track file holds three track ids, but agent 1 is never inside the
        map, so only two scenarios come out of three requested. In both of
        them, agents 2 and 7 report a first valid timestamp of 0.
        """
        params = ParameterServer()
        here = os.path.dirname(__file__)
        generation_params = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_params["MapFilename"] = os.path.join(
            here, "data/DR_CHN_Merging_ZS_partial_v02.xodr")
        generation_params["TrackFilenameList"] = [
            os.path.join(
                here,
                "data/interaction_dataset_CHN_Merging_dummy_track_incomplete.csv")
        ]
        generation_params["StartingOffsetMs"] = 0

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=3)
        self.assertEqual(scenario_generation.get_num_scenarios(), 2)

        for scenario_idx in range(2):
            # Each agent lookup requests the scenario and its world state anew.
            agents = [
                scenario_generation.get_scenario(scenario_idx).GetWorldState()
                .agents[agent_id] for agent_id in (2, 7)
            ]
            for agent in agents:
                self.assertEqual(agent.first_valid_timestamp, 0.0)