def test_collect_demonstrations(self):
  """Collect demonstrations, filter them, and round-trip them through disk.

  Runs a demo behavior on a discrete highway runtime through a
  DemonstrationCollector and checks the following:

  - the collection result is written below ./test_demo_collected;
  - processing with the "goal_r1" criterion yields the expected number of
    experiences;
  - dumping the collector to ./final_collections and loading it back
    preserves that number.
  """
  collection_dir = "./test_demo_collected"
  dump_dir = "./final_collections"
  # Expected: 2 scenarios (only every second one reaches the goal)
  # x 3 steps each (4 executed, but the first is not counted).
  expected_experiences = 2 * 3

  params = ParameterServer()
  blueprint = DiscreteHighwayBlueprint(params, number_of_senarios=10,
                                       random_seed=0)
  env = SingleAgentRuntime(blueprint=blueprint, render=False)
  # Swap in the observer, action wrapper and evaluator this test relies on.
  env._observer = NearestAgentsObserver(params)
  env._action_wrapper = BehaviorDiscreteMacroActionsML(params)
  env._evaluator = TestEvaluator()

  demo_module = \
    bark_ml.library_wrappers.lib_fqf_iqn_qrdqn.tests.test_demo_behavior
  demo_behavior = demo_module.TestDemoBehavior(params)

  collector = DemonstrationCollector()
  collection_result = collector.CollectDemonstrations(
    env, demo_behavior, 4, collection_dir,
    use_mp_runner=False, runner_init_params={"deepcopy": False})
  self.assertTrue(os.path.exists(collection_dir + "/collection_result"))
  print(collection_result.get_data_frame().to_string())

  experiences = collector.ProcessCollectionResult(
    eval_criteria={"goal_r1": lambda x: x})
  self.assertEqual(len(experiences), expected_experiences)

  # Persist and reload to make sure the experiences survive serialization.
  collector.dump(dump_dir)
  reloaded = DemonstrationCollector.load(dump_dir)
  reloaded_experiences = reloaded.GetDemonstrationExperiences()
  print(reloaded_experiences)
  self.assertEqual(len(reloaded_experiences), expected_experiences)
def test_iqn_agent(self):
  """Train an IQNAgent briefly, save checkpoints, and reload them.

  Saves both a "best" and a "last" checkpoint to ./save_dir. It then
  reloads each checkpoint without an environment, and reloads "last" with
  an environment and trains that agent for one more episode. Finally it
  checks that the env-less reloaded agents expose the same discrete
  action-space size as the original agent.
  """
  save_dir = "./save_dir"

  params = ParameterServer()
  # Tiny step budgets -- presumably to keep the test quick.
  params["ML"]["BaseAgent"]["NumSteps"] = 2
  params["ML"]["BaseAgent"]["MaxEpisodeSteps"] = 2

  bp = DiscreteHighwayBlueprint(params, number_of_senarios=10, random_seed=0)
  env = SingleAgentRuntime(blueprint=bp, render=False)
  env._observer = NearestAgentsObserver(params)
  env._action_wrapper = BehaviorDiscreteMacroActionsML(params)

  iqn_agent = IQNAgent(agent_save_dir=save_dir, env=env, params=params)
  iqn_agent.train_episode()
  iqn_agent.save(checkpoint_type="best")
  iqn_agent.save(checkpoint_type="last")

  loaded_agent = IQNAgent(agent_save_dir=save_dir, checkpoint_load="best")
  loaded_agent2 = IQNAgent(agent_save_dir=save_dir, checkpoint_load="last")
  loaded_agent_with_env = IQNAgent(env=env, agent_save_dir=save_dir,
                                   checkpoint_load="last")
  # A reloaded agent with an environment must be able to continue training.
  loaded_agent_with_env.train_episode()

  expected_num_actions = iqn_agent.ml_behavior.action_space.n
  self.assertEqual(loaded_agent.ml_behavior.action_space.n,
                   expected_num_actions)
  # Check the "last" checkpoint too. Before, it was loaded but never verified.
  self.assertEqual(loaded_agent2.ml_behavior.action_space.n,
                   expected_num_actions)