예제 #1
0
def test_log_episode():
    """Check that every ``Logger.log_episode`` call records exactly four scalars.

    The fake writer accumulates calls in ``add_scalar_call``. After the k-th
    episode (0-based) there must be ``4 * (k + 1)`` recorded calls. This holds
    even for the two-record episode, so the count does not depend on episode
    length. The third element of the most recent call must equal the episode
    index (presumably the ``global_step`` argument; confirm against
    ``FakeSummaryWriter``).
    """
    writer = FakeSummaryWriter()
    episodes = [
        [Record(1) for _ in range(4)],
        [Record(2) for _ in range(4)],
        [Record(1.0) for _ in range(4)],
        [Record(2.0) for _ in range(4)],
        [Record(value) for value in (1, 2, 3, 4)],
        [Record(-1), Record(1)],
        [Record(value) for value in (-10, -15, -20, -15)],
    ]

    for episode_index, episode in enumerate(episodes):
        Logger.log_episode(writer, episode, episode_index)
        recorded_calls = writer.add_scalar_call
        assert len(recorded_calls) == 4 * (episode_index + 1)
        assert recorded_calls[-1][2] == episode_index
예제 #2
0
def test_evaluate():
    """Check that ``Logger.evaluate`` never leaves anything in ``logger.episodes``.

    Each step list is assigned directly to ``logger.current_steps`` before
    ``evaluate()`` runs. Afterwards the logger's episode list must still be empty.
    """
    logger = Logger()
    candidate_steps = [
        [Record(1) for _ in range(4)],
        [Record(2) for _ in range(4)],
        [Record(1.0) for _ in range(4)],
        [Record(2.0) for _ in range(4)],
        [Record(value) for value in (1, 2, 3, 4)],
        [Record(-1), Record(1)],
        [Record(value) for value in (-10, -15, -20, -15)],
    ]

    for steps in candidate_steps:
        logger.current_steps = steps
        logger.evaluate()
        assert len(logger.episodes) == 0
예제 #3
0
def test_add_episode():
    """Check that ``Logger.add_episode`` appends each episode to ``logger.episodes``.

    After every call, the episode count must grow by one and the newest
    element must equal the episode just added.
    """
    logger = Logger()
    episodes = [
        [Record(1) for _ in range(4)],
        [Record(2) for _ in range(4)],
        [Record(1.0) for _ in range(4)],
        [Record(2.0) for _ in range(4)],
        [Record(value) for value in (1, 2, 3, 4)],
        [Record(-1), Record(1)],
        [Record(value) for value in (-10, -15, -20, -15)],
    ]

    for expected_count, episode in enumerate(episodes, start=1):
        logger.add_episode(episode)
        assert len(logger.episodes) == expected_count
        assert episode == logger.episodes[-1]
예제 #4
0
def test_add_step():
    """Check that ``Logger.add_steps`` appends each record to ``logger.current_steps`` in order.

    The records carry distinct values so the ``[-1]`` check can tell them
    apart even if ``Record`` compares by value. With four identical records,
    that check could not notice a wrong or stale element being stored.
    """
    logger = Logger()
    list_records = [Record(1.0), Record(2.0), Record(3.0), Record(4.0)]
    for ite, record in enumerate(list_records):
        logger.add_steps(record)
        assert ite + 1 == len(logger.current_steps)
        assert record == logger.current_steps[-1]
    # The full history must match the insertion order, not just the last item.
    assert list(logger.current_steps) == list_records
예제 #5
0
def test_logger_init():
    """Check that a new Logger is empty and owns a ``SummaryWriter``.

    This is checked for the default constructor and for one given an explicit
    ``log_dir``. For the latter, the writer must report that same directory.
    """
    def check_fresh(candidate):
        # A fresh logger has no episodes, no pending steps, and a real writer.
        assert not candidate.episodes
        assert not candidate.current_steps
        assert isinstance(candidate.summary_writer, SummaryWriter)

    logger = Logger()
    check_fresh(logger)

    logger = Logger(log_dir="des")
    check_fresh(logger)
    assert logger.summary_writer.log_dir == "des"
예제 #6
0
def test_write_log():
    """Smoke test: call ``Logger.write_log`` once per episode without raising.

    No assertions are made; the test passes as long as every call completes.
    NOTE(review): the output directory is the shared "./runs" folder, so the
    test presumably leaves files behind; consider a temporary directory.
    """
    episodes = [
        [Record(1) for _ in range(4)],
        [Record(2) for _ in range(4)],
        [Record(1.0) for _ in range(4)],
        [Record(2.0) for _ in range(4)],
        [Record(value) for value in (1, 2, 3, 4)],
        [Record(-1), Record(1)],
        [Record(value) for value in (-10, -15, -20, -15)],
    ]

    for step_index, episode in enumerate(episodes):
        Logger.write_log("./runs", episode, step_index)
예제 #7
0
    def __init__(self, environment, agent, log_dir="./runs"):
        """Set up the environment, the agent and the logger for this runner.

        :param environment: passed to ``self.get_environment``. When ``agent``
            is a class, the resulting object must expose ``action_space`` and
            ``observation_space``.
        :param agent: either a subclass of ``AgentInterface`` or an
            ``AgentInterface`` instance. A subclass is instantiated here with the
            environment's observation and action spaces. An instance is used
            as-is, with a warning that its dimensions are not checked.
        :param log_dir: directory handed to ``Logger(log_dir=...)``.
        :raises TypeError: if ``agent`` is neither a subclass nor an instance
            of ``AgentInterface``.
        """
        self.environment = self.get_environment(environment)
        # Test for a real subclass. The previous check,
        # ``isinstance(agent, type(AgentInterface))``, only tested the metaclass
        # and so accepted unrelated classes.
        if isinstance(agent, type) and issubclass(agent, AgentInterface):
            # Reuse the environment built above rather than calling
            # get_environment again. get_environment may build a new object on
            # each call (not verifiable from here), so the spaces should come
            # from the environment actually stored on self.
            action_space = self.environment.action_space
            observation_space = self.environment.observation_space
            self.agent = agent(observation_space=observation_space, action_space=action_space)
        elif isinstance(agent, AgentInterface):
            import warnings
            # stacklevel=2 attributes the warning to the code that constructed this object.
            warnings.warn("be sure of your agent need to have good input and output dimension",
                          stacklevel=2)
            self.agent = agent
        else:
            raise TypeError("this type (" + str(type(agent)) + ") is not a subclass or an instance of AgentInterface")

        self.logger = Logger(log_dir=log_dir)