def test_dataset_scenario_generation_full_late(self):
        """Check that a late-appearing agent is invalid at the first world time step.

        With the ``interaction_dataset_dummy_track_late.csv`` track file,
        agent 1 is expected to become valid at t=0.0 and agent 2 only at
        t=0.3. At the initial world time, agent 1 must therefore be valid and
        agent 2 must not.
        """
        params = ParameterServer()

        map_filename = os.path.join(os.path.dirname(__file__),
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename = os.path.join(
            os.path.dirname(__file__),
            "data/interaction_dataset_dummy_track_late.csv")

        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "MapFilename"] = map_filename
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "TrackFilenameList"] = [track_filename]
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]["StartingOffsetMs"] = 0

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=1)

        scenario = scenario_generation.get_scenario(0)
        world_state = scenario.GetWorldState()
        agent1 = world_state.GetAgent(1)
        agent2 = world_state.GetAgent(2)

        # Check the types first, so that a wrong object fails with a clear
        # message rather than an AttributeError on first_valid_timestamp.
        self.assertIsInstance(agent1, Agent)
        self.assertIsInstance(agent2, Agent)

        self.assertAlmostEqual(agent1.first_valid_timestamp, 0.0)
        self.assertAlmostEqual(agent2.first_valid_timestamp, 0.3)

        self.assertTrue(agent1.IsValidAtTime(world_state.time))
        self.assertFalse(agent2.IsValidAtTime(world_state.time))
    def test_dataset_scenario_generation_full_num_scenarios(self):
        """Check that the plain dummy track file yields the 2 requested scenarios.

        NOTE: this method must not be named
        ``test_dataset_scenario_generation_full``. A later method of that name
        in this class would rebind the class attribute, and this test would
        silently never run.
        """
        params = ParameterServer()

        map_filename = os.path.join(os.path.dirname(__file__),
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename = os.path.join(os.path.dirname(__file__),
                                      "data/interaction_dataset_dummy_track.csv")

        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "MapFilename"] = map_filename
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "TrackFilenameList"] = [track_filename]

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=2)

        self.assertEqual(scenario_generation.get_num_scenarios(), 2)
    def test_dataset_scenario_generation_full(self):
        """Two track ids that both become valid at t=0 produce two scenarios.

        In each of the two scenarios:

        * agents 1 and 2 have ``first_valid_timestamp`` 0.0,
        * ``agents_valid`` is empty before the world is stepped, and
        * both agents are in ``agents_valid`` after one step of 0.01.
        """
        here = os.path.dirname(__file__)
        map_filename = os.path.join(here,
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename = os.path.join(
            here, "data/interaction_dataset_DEU_Merging_dummy_track.csv")

        params = ParameterServer()
        generation_cfg = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_cfg["MapFilename"] = map_filename
        generation_cfg["TrackFilenameList"] = [track_filename]
        generation_cfg["StartingOffsetMs"] = 0

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=2)
        self.assertEqual(scenario_generation.get_num_scenarios(), 2)

        for scenario_idx in range(2):
            scenario = scenario_generation.get_scenario(scenario_idx)
            agents = scenario.GetWorldState().agents
            for agent_id in (1, 2):
                self.assertEqual(agents[agent_id].first_valid_timestamp, 0.0)

            # Agents are initialized with Behavior::NotValidYet, so none of
            # them is valid until the world has been stepped once.
            self.assertEqual(
                list(scenario.GetWorldState().agents_valid.keys()), [])
            scenario.GetWorldState().Step(0.01)
            self.assertEqual(
                list(scenario.GetWorldState().agents_valid.keys()), [1, 2])
    def test_dataset_scenario_generation_full_outside3(self):
        """Ego agent 3, outside the map at first, is valid at the first world step.

        The track file has three track ids. In the track file, agent 3 becomes
        valid later than the other agents. In the scenario where it is the ego
        agent (index 2), that earlier part is cut off, so all three agents must
        be valid and inside their road corridor at t=0. All three must still be
        valid after one step of 0.05.
        """
        params = ParameterServer()

        map_filename = os.path.join(os.path.dirname(__file__),
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename = os.path.join(
            os.path.dirname(__file__),
            "data/interaction_dataset_DEU_Merging_dummy_track_outside.csv")

        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "MapFilename"] = map_filename
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "TrackFilenameList"] = [track_filename]
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]["StartingOffsetMs"] = 0

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=3)

        scenario = scenario_generation.get_scenario(2)
        # Use assertEqual, not assertAlmostEqual. On a mismatch the latter
        # subtracts the operands, which raises TypeError for lists instead of
        # reporting a clean test failure.
        self.assertEqual(scenario.eval_agent_ids, [3])
        world_state = scenario.GetWorldState()
        agents = [world_state.GetAgent(agent_id) for agent_id in (1, 2, 3)]

        # All agents should be valid and inside their corridor at the beginning.
        world_state.time = 0
        for agent in agents:
            self.assertIsInstance(agent, Agent)
            self.assertTrue(agent.IsValidAtTime(world_state.time))
            self.assertTrue(agent.InsideRoadCorridor())

        world_state.Step(0.05)
        self.assertEqual(len(world_state.agents_valid), 3)
    def test_dataset_scenario_generation_full_incomplete_num_scenarios(self):
        """A track with one agent outside the map yields fewer scenarios than requested.

        Three scenarios are requested, but agent 1 is not part of the map, so
        only 2 should be generated.

        NOTE: this method must not be named
        ``test_dataset_scenario_generation_full_incomplete``. A later method of
        that name in this class would rebind the class attribute, and this test
        would silently never run.
        NOTE(review): the track file here
        ("interaction_dataset_dummy_track_incomplete.csv") differs from the
        CHN_Merging-prefixed file used by the other incomplete test. Confirm it
        still exists now that this test actually runs.
        """
        params = ParameterServer()

        map_filename = os.path.join(os.path.dirname(__file__),
                                    "data/DR_CHN_Merging_ZS_partial_v02.xodr")
        track_filename = os.path.join(
            os.path.dirname(__file__),
            "data/interaction_dataset_dummy_track_incomplete.csv")

        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "MapFilename"] = map_filename
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "TrackFilenameList"] = [track_filename]

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=3)

        # Agent 1 is not part of the map, so only 2 scenarios are generated.
        self.assertEqual(scenario_generation.get_num_scenarios(), 2)
    def test_setting_behavior_of_ego_agent(self):
        """The ``"ego"`` entry of ``BehaviorModel`` is applied to the ego agent.

        For every generated scenario, the agent whose id is the first entry of
        ``eval_agent_ids`` must have a behavior model whose string form ends
        with ``BehaviorMobilRuleBased`` after its last dot.
        """
        params = ParameterServer()

        map_filename = os.path.join(os.path.dirname(__file__),
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename = os.path.join(
            os.path.dirname(__file__),
            "data/interaction_dataset_dummy_track.csv")

        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "MapFilename"] = map_filename
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "TrackFilenameList"] = [track_filename]

        # Set the behavior model of the ego agent only.
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]["BehaviorModel"] = \
            {"ego": "BehaviorMobilRuleBased"}

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=2)
        for scenario, _ in scenario_generation:
            # Assume there is only one ego agent.
            ego_id = scenario.eval_agent_ids[0]
            ego_agent = next(
                (agent for agent in scenario._agent_list
                 if agent.id == ego_id), None)
            if ego_agent is None:
                # Use fail() rather than assertTrue(False) so the report says
                # what went wrong.
                self.fail("No agent with ego id {} in scenario._agent_list"
                          .format(ego_id))
            self.assertEqual(
                str(ego_agent.behavior_model).rsplit(".", 1)[-1],
                "BehaviorMobilRuleBased")
    def test_excluded_tracks(self):
        """No generated scenario uses a (track file, ego id) pair from ExcludeTracks.

        Up to 10 scenarios are requested from two track files. For each
        generated scenario, the pair of its track file and its first evaluated
        agent id must not be one of the excluded pairs.
        """
        here = os.path.dirname(__file__)
        map_filename = os.path.join(here,
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename_1 = os.path.join(
            here, "data/interaction_dataset_dummy_track.csv")
        track_filename_2 = os.path.join(
            here, "data/interaction_dataset_dummy_track_2.csv")

        excluded = {
            track_filename_1: [2],
            track_filename_2: [1, 3],
        }
        # Flatten {track: [ids]} into [(track, id), ...] before handing the
        # mapping over to the parameter server.
        forbidden_pairs = [(track, ego)
                           for track, egos in excluded.items()
                           for ego in egos]

        params = ParameterServer()
        generation_cfg = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_cfg["MapFilename"] = map_filename
        generation_cfg["TrackFilenameList"] = [track_filename_1,
                                               track_filename_2]
        generation_cfg["ExcludeTracks"] = excluded

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=10)
        for scenario, _ in scenario_generation:
            pair = (scenario.json_params["track_file"],
                    scenario.eval_agent_ids[0])
            self.assertNotIn(pair, forbidden_pairs)
    def test_dataset_scenario_generation_full_incomplete(self):
        """Agent 1 never enters the map, so only 2 of 3 requested scenarios exist.

        The track file has three track ids. In both generated scenarios,
        agents 2 and 7 must have ``first_valid_timestamp`` 0.0.
        """
        here = os.path.dirname(__file__)
        map_filename = os.path.join(here,
                                    "data/DR_CHN_Merging_ZS_partial_v02.xodr")
        track_filename = os.path.join(
            here,
            "data/interaction_dataset_CHN_Merging_dummy_track_incomplete.csv")

        params = ParameterServer()
        generation_cfg = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_cfg["MapFilename"] = map_filename
        generation_cfg["TrackFilenameList"] = [track_filename]
        generation_cfg["StartingOffsetMs"] = 0

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=3)
        self.assertEqual(scenario_generation.get_num_scenarios(), 2)

        for scenario_idx in range(2):
            agents = scenario_generation.get_scenario(
                scenario_idx).GetWorldState().agents
            for agent_id in (2, 7):
                self.assertEqual(agents[agent_id].first_valid_timestamp, 0.0)
    def test_included_tracks(self):
        """Only (track file, ego id) pairs listed in IncludeTracks become scenarios.

        If fewer scenarios are requested than there are listed pairs, exactly
        the requested number is generated. If more are requested, one scenario
        per listed pair (3) is generated.
        """
        here = os.path.dirname(__file__)
        map_filename = os.path.join(here,
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename_1 = os.path.join(
            here, "data/interaction_dataset_dummy_track.csv")
        track_filename_2 = os.path.join(
            here, "data/interaction_dataset_dummy_track_2.csv")

        included = {
            track_filename_1: [2],
            track_filename_2: [1, 3],
        }
        # Flatten {track: [ids]} into [(track, id), ...] before handing the
        # mapping over to the parameter server.
        allowed_pairs = [(track, ego)
                         for track, egos in included.items()
                         for ego in egos]

        params = ParameterServer()
        generation_cfg = params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]
        generation_cfg["MapFilename"] = map_filename
        generation_cfg["TrackFilenameList"] = [track_filename_1,
                                               track_filename_2]
        generation_cfg["IncludeTracks"] = included

        # Each case is (num_scenarios requested, number expected to be generated):
        #   requested < listed pairs -> exactly the requested number;
        #   requested > listed pairs -> every listed pair, and no more.
        for requested, expected in ((2, 2), (4, 3)):
            scenario_generation = InteractionDatasetScenarioGenerationFull(
                params=params, num_scenarios=requested)
            self.assertEqual(scenario_generation.get_num_scenarios(), expected)
            for scenario, _ in scenario_generation:
                pair = (scenario.json_params["track_file"],
                        scenario.eval_agent_ids[0])
                self.assertIn(pair, allowed_pairs)
    def test_dataset_scenario_generation_full_outside3_behavior_overwritten(
            self):
        """Ego agent 3, outside the map at first, is valid at the first step; behavior overwritten.

        The track file has three track ids. In the track file, agent 3 becomes
        valid later than the other agents. In the scenario where it is the ego
        agent (index 2), that earlier part is cut off, so all three agents must
        be valid and inside their road corridor at t=0. ``BehaviorModel`` is
        set to ``BehaviorMobilRuleBased``. The non-ego agents must use it, while
        the ego agent keeps ``BehaviorStaticTrajectory``. No collision may be
        reported before or after one step of 0.05.
        """
        params = ParameterServer()

        map_filename = os.path.join(os.path.dirname(__file__),
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename = os.path.join(
            os.path.dirname(__file__),
            "data/interaction_dataset_DEU_Merging_dummy_track_outside.csv")

        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "MapFilename"] = map_filename
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "TrackFilenameList"] = [track_filename]
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]["StartingOffsetMs"] = 0
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "BehaviorModel"] = "BehaviorMobilRuleBased"

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=3)

        scenario = scenario_generation.get_scenario(2)
        # Use assertEqual, not assertAlmostEqual. On a mismatch the latter
        # subtracts the operands, which raises TypeError for lists instead of
        # reporting a clean test failure.
        self.assertEqual(scenario.eval_agent_ids, [3])
        world_state = scenario.GetWorldState()
        agent31 = world_state.GetAgent(1)
        agent32 = world_state.GetAgent(2)
        agent33 = world_state.GetAgent(3)

        # The non-ego agents get the overwritten BehaviorMobilRuleBased model;
        # the ego agent (3) keeps a BehaviorStaticTrajectory.
        self.assertIsInstance(agent31.behavior_model, BehaviorMobilRuleBased)
        self.assertIsInstance(agent32.behavior_model, BehaviorMobilRuleBased)
        self.assertIsInstance(agent33.behavior_model, BehaviorStaticTrajectory)

        # All agents should be valid and inside their corridor at the beginning.
        world_state.time = 0
        for agent in (agent31, agent32, agent33):
            self.assertIsInstance(agent, Agent)
            self.assertTrue(agent.IsValidAtTime(world_state.time))
            self.assertTrue(agent.InsideRoadCorridor())

        evaluator = EvaluatorCollisionAgents()
        world_state.AddEvaluator("collision", evaluator)
        info = world_state.Evaluate()
        self.assertEqual(info["collision"], False)

        world_state.Step(0.05)

        # NOTE(review): registering the evaluator again under the same name may
        # be redundant if evaluators persist across Step(); kept as-is.
        evaluator = EvaluatorCollisionAgents()
        world_state.AddEvaluator("collision", evaluator)
        info = world_state.Evaluate()
        self.assertEqual(info["collision"], False)

        self.assertEqual(len(world_state.agents_valid), 3)
    def test_dataset_scenario_generation_full_outside1_behavior_overwritten(
            self):
        """Non-ego agent 3 starts outside the map: invalid at first, valid later.

        The track file has three track ids, and the scenario at index 0 has
        agent 1 as its ego agent. Agent 3 is outside the map at the
        beginning. It must be identified as invalid at the first world time
        step and after a step of 0.05, and become valid once world time has
        moved past its ``first_valid_timestamp``. ``BehaviorModel`` is set to
        ``BehaviorMobilRuleBased`` for the non-ego agents.
        """
        params = ParameterServer()

        map_filename = os.path.join(os.path.dirname(__file__),
                                    "data/DR_DEU_Merging_MT_v01_shifted.xodr")
        track_filename = os.path.join(
            os.path.dirname(__file__),
            "data/interaction_dataset_DEU_Merging_dummy_track_outside.csv")

        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "MapFilename"] = map_filename
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "TrackFilenameList"] = [track_filename]
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"]["StartingOffsetMs"] = 0
        params["Scenario"]["Generation"][
            "InteractionDatasetScenarioGenerationFull"][
                "BehaviorModel"] = "BehaviorMobilRuleBased"

        scenario_generation = InteractionDatasetScenarioGenerationFull(
            params=params, num_scenarios=1)

        scenario = scenario_generation.get_scenario(0)
        # Use assertEqual, not assertAlmostEqual. On a mismatch the latter
        # subtracts the operands, which raises TypeError for lists instead of
        # reporting a clean test failure.
        self.assertEqual(scenario.eval_agent_ids, [1])

        world_state = scenario.GetWorldState()
        agent11 = world_state.GetAgent(1)
        agent12 = world_state.GetAgent(2)
        agent13 = world_state.GetAgent(3)

        def assert_agent(agent, expect_valid):
            """Assert the agent's type, validity at the current world time, and corridor membership."""
            self.assertIsInstance(agent, Agent)
            self.assertEqual(agent.IsValidAtTime(world_state.time),
                             expect_valid)
            # Only states from the point where the agent is inside the map are
            # used, so this holds even while the agent is not yet valid.
            self.assertTrue(agent.InsideRoadCorridor())

        # The ego agent (1) keeps BehaviorStaticTrajectory; the others are
        # overwritten with BehaviorMobilRuleBased.
        self.assertIsInstance(agent11.behavior_model, BehaviorStaticTrajectory)
        self.assertIsInstance(agent12.behavior_model, BehaviorMobilRuleBased)
        self.assertIsInstance(agent13.behavior_model, BehaviorMobilRuleBased)

        self.assertAlmostEqual(agent11.first_valid_timestamp, 0.0)
        self.assertAlmostEqual(agent12.first_valid_timestamp, 0.0)
        self.assertNotEqual(agent13.first_valid_timestamp, 0.0)

        # agent13 should not be valid at the beginning, as it is outside the map.
        world_state.time = 0
        assert_agent(agent11, True)
        assert_agent(agent12, True)
        assert_agent(agent13, False)

        # agent13 should still not be valid after the first step.
        world_state.Step(0.05)
        assert_agent(agent11, True)
        assert_agent(agent12, True)
        assert_agent(agent13, False)
        self.assertEqual(list(world_state.agents_valid.keys()), [1, 2])

        # agent13 should become valid at some point. Step() is called here with
        # an increment (see the 0.01 "one more step" below), so after
        # Step(first_valid_timestamp) world time is at or past agent13's first
        # valid timestamp. One extra step is needed because IsValidAtTime()
        # uses the previous time stamp.
        world_state.Step(agent13.first_valid_timestamp)
        world_state.Step(0.01)
        assert_agent(agent11, True)
        assert_agent(agent12, True)
        assert_agent(agent13, True)
        self.assertEqual(list(world_state.agents_valid.keys()), [1, 2, 3])