def test_add_step():
    """Logger.add_steps appends each record to ``current_steps`` in call order.

    Distinct values are used so that, if ``Record`` compares by value, the
    check on the last element can actually tell the records apart (four
    identical ``Record(1.0)`` instances could not detect the wrong record
    being appended).
    """
    logger = Logger()
    list_records = [Record(1.0), Record(2.0), Record(3.0), Record(4.0)]
    for ite, record in enumerate(list_records):
        logger.add_steps(record)
        # The step list grows by exactly one per call...
        assert ite + 1 == len(logger.current_steps)
        # ...and its newest entry is the record just added.
        assert record == logger.current_steps[-1]
def test_record_init():
    """Record raises TypeError for non-numeric values and stores numeric ones as-is."""
    invalid_inputs = [None, "deeee", [], {}, object()]
    valid_inputs = [1, 1.0, 0.0, 0, 125, 125.125]

    for bad in invalid_inputs:
        with pytest.raises(TypeError):
            Record(value=bad)

    for good in valid_inputs:
        assert Record(value=good).value == good
def do_step(self, observation, learn=True, logger=None, render=True):
    """Run a single agent/environment interaction step.

    :param observation: current observation, passed to ``self.agent.get_action``
    :param learn: if true, call ``self.agent.learn`` with the resulting
        transition ``(observation, action, reward, next_observation, done)``
    :type learn: bool
    :param logger: optional logger; when not ``None`` the step reward is
        recorded via ``logger.add_steps(Record(reward))``
    :param render: if true, call ``self.render()`` before choosing the action
    :type render: bool
    :return: ``(next_observation, done, reward)`` -- note this order differs
        from the ``(next_observation, reward, done, info)`` tuple unpacked
        from ``self.environment.step``
    """
    if render:
        self.render()
    action = self.agent.get_action(observation=observation)
    # environment.step must return a 4-tuple; ``info`` is not used here.
    next_observation, reward, done, _info = self.environment.step(action)
    if learn:
        self.agent.learn(observation, action, reward, next_observation, done)
    # Explicit None check rather than truthiness: a logger object that
    # defines __len__/__bool__ (e.g. empty at episode start) must still
    # receive the step.
    if logger is not None:
        logger.add_steps(Record(reward))
    return next_observation, done, reward
def test_max_records():
    """Record.max_records returns the largest record value, 0 for an empty list,
    and raises TypeError for the non-Record inputs listed below."""
    cases = [
        ([Record(1), Record(1), Record(1), Record(1)], 1),
        ([Record(2), Record(2), Record(2), Record(2)], 2),
        ([Record(1.0), Record(1.0), Record(1.0), Record(1.0)], 1.0),
        ([Record(2.0), Record(2.0), Record(2.0), Record(2.0)], 2.0),
        ([Record(1), Record(2), Record(3), Record(4)], 4),
        ([Record(-1), Record(1)], 1),
        ([Record(-10), Record(-15), Record(-20), Record(-15)], -10),
    ]
    for records, expected in cases:
        assert expected == Record.max_records(records)

    # Strings, lists of non-Records, and bare numbers are all rejected.
    for bad_input in ("dsdzs", ["dzdqzdq"], [1548, 1548], 1254):
        with pytest.raises(TypeError):
            Record.max_records(bad_input)

    assert 0 == Record.max_records([])
def test_write_log():
    """Logger.write_log runs without error for several episodes of records.

    write_log takes a directory path (originally the hard-coded "./runs");
    a temporary directory is used so that whatever it writes does not
    accumulate in the current working directory across test runs.
    """
    import tempfile

    list_steps = [
        [Record(1), Record(1), Record(1), Record(1)],
        [Record(2), Record(2), Record(2), Record(2)],
        [Record(1.0), Record(1.0), Record(1.0), Record(1.0)],
        [Record(2.0), Record(2.0), Record(2.0), Record(2.0)],
        [Record(1), Record(2), Record(3), Record(4)],
        [Record(-1), Record(1)],
        [Record(-10), Record(-15), Record(-20), Record(-15)],
    ]
    with tempfile.TemporaryDirectory() as log_dir:
        for ite, records in enumerate(list_steps):
            Logger.write_log(log_dir, records, ite)
def test_log_episode():
    """Logger.log_episode makes four add_scalar calls per episode, and the most
    recent call carries the episode index at position 2."""
    writer = FakeSummaryWriter()
    episodes = [
        [Record(1), Record(1), Record(1), Record(1)],
        [Record(2), Record(2), Record(2), Record(2)],
        [Record(1.0), Record(1.0), Record(1.0), Record(1.0)],
        [Record(2.0), Record(2.0), Record(2.0), Record(2.0)],
        [Record(1), Record(2), Record(3), Record(4)],
        [Record(-1), Record(1)],
        [Record(-10), Record(-15), Record(-20), Record(-15)],
    ]
    for episodes_done, records in enumerate(episodes, start=1):
        episode_index = episodes_done - 1
        Logger.log_episode(writer, records, episode_index)
        expected_calls = episodes_done * 4
        assert expected_calls == len(writer.add_scalar_call)
        assert episode_index == writer.add_scalar_call[-1][2]
def test_evaluate():
    """After each call to Logger.evaluate on a step list, ``episodes`` stays empty."""
    logger = Logger()
    all_steps = [
        [Record(1), Record(1), Record(1), Record(1)],
        [Record(2), Record(2), Record(2), Record(2)],
        [Record(1.0), Record(1.0), Record(1.0), Record(1.0)],
        [Record(2.0), Record(2.0), Record(2.0), Record(2.0)],
        [Record(1), Record(2), Record(3), Record(4)],
        [Record(-1), Record(1)],
        [Record(-10), Record(-15), Record(-20), Record(-15)],
    ]
    for steps in all_steps:
        logger.current_steps = steps
        logger.evaluate()
        assert len(logger.episodes) == 0
def test_add_episode():
    """Logger.add_episode appends each episode to ``episodes`` in call order."""
    logger = Logger()
    episodes_to_add = [
        [Record(1), Record(1), Record(1), Record(1)],
        [Record(2), Record(2), Record(2), Record(2)],
        [Record(1.0), Record(1.0), Record(1.0), Record(1.0)],
        [Record(2.0), Record(2.0), Record(2.0), Record(2.0)],
        [Record(1), Record(2), Record(3), Record(4)],
        [Record(-1), Record(1)],
        [Record(-10), Record(-15), Record(-20), Record(-15)],
    ]
    for expected_count, episode in enumerate(episodes_to_add, 1):
        logger.add_episode(episode)
        assert len(logger.episodes) == expected_count
        assert episode == logger.episodes[-1]