def test_box_logging(self):
        """Check that the state-tracking log contains every expected column.

        Runs a ``StaticAgent`` on a seeded Luby environment wrapped in
        ``StateTrackingWrapper``, loads the module's log file back into a
        wide dataframe, and asserts that each expected ``state_*`` column is
        present and has no missing values.
        """
        temp_dir = tempfile.TemporaryDirectory()
        # Register cleanup right away so the directory is removed even when
        # the benchmark run raises or one of the assertions below fails.
        self.addCleanup(temp_dir.cleanup)

        seed = 0
        episodes = 10
        logger = Logger(
            output_path=Path(temp_dir.name),
            experiment_name="test_box_logging",
            step_write_frequency=None,
            episode_write_frequency=1,
        )

        bench = LubyBenchmark()
        bench.set_seed(seed)
        env = bench.get_environment()
        state_logger = logger.add_module(StateTrackingWrapper)
        wrapped = StateTrackingWrapper(env, logger=state_logger)
        agent = StaticAgent(env, 1)
        logger.set_env(env)

        run_benchmark(wrapped, agent, episodes, logger)
        # Close the module logger before its log file is read back.
        state_logger.close()

        logs = load_logs(state_logger.get_logfile())
        dataframe = log2dataframe(logs, wide=True)

        state_columns = [
            "state_Action t (current)",
            "state_Step t (current)",
            "state_Action t-1",
            "state_Action t-2",
            "state_Step t-1",
            "state_Step t-2",
        ]

        for state_column in state_columns:
            self.assertIn(state_column, dataframe.columns)
            self.assertTrue(
                dataframe[state_column].notna().all(),
                msg=f"column {state_column!r} contains missing values",
            )
# Example #2
# 0
    def test_logging_discrete(self):
        """Check that the logged actions match the pinned reference sequence.

        Runs a ``RandomAgent`` on a Luby environment whose benchmark and
        action space are both seeded with the same value, wrapped in
        ``ActionFrequencyWrapper``. The module's log file is loaded back into
        a wide dataframe and its ``action`` column is compared against a
        fixed list of 80 expected actions.
        """
        temp_dir = tempfile.TemporaryDirectory()
        # Register cleanup right away so the directory is removed even when
        # the benchmark run raises or the final assertion fails.
        self.addCleanup(temp_dir.cleanup)

        seed = 0
        episodes = 10
        logger = Logger(
            output_path=Path(temp_dir.name),
            experiment_name="test_discrete_logging",
            step_write_frequency=None,
            episode_write_frequency=1,
        )

        bench = LubyBenchmark()
        bench.set_seed(seed)
        env = bench.get_environment()
        env.seed_action_space(seed)

        action_logger = logger.add_module(ActionFrequencyWrapper)
        wrapped = ActionFrequencyWrapper(env, logger=action_logger)
        agent = RandomAgent(env)
        logger.set_env(env)

        run_benchmark(wrapped, agent, episodes, logger)
        # Close the module logger before its log file is read back.
        action_logger.close()

        logs = load_logs(action_logger.get_logfile())
        dataframe = log2dataframe(logs, wide=True)

        # Reference action sequence for seed 0, ten values per row.
        # fmt: off
        expected_actions = [
            0, 3, 5, 4, 3, 5, 5, 5, 3, 3,
            2, 1, 0, 1, 2, 0, 1, 1, 0, 1,
            2, 4, 3, 0, 1, 3, 0, 3, 3, 3,
            4, 4, 4, 5, 4, 0, 4, 2, 1, 3,
            4, 2, 1, 3, 3, 2, 0, 5, 2, 5,
            2, 1, 5, 3, 2, 5, 1, 0, 2, 3,
            1, 3, 2, 3, 2, 4, 3, 4, 0, 5,
            5, 1, 5, 0, 1, 5, 5, 3, 3, 2,
        ]
        # fmt: on

        self.assertListEqual(dataframe.action.to_list(), expected_actions)