def test_sum_records():
    """Exercise ``Record.sum_records`` on valid, invalid and empty inputs."""

    def build(*values):
        # One distinct Record per value, in the order given.
        return [Record(value) for value in values]

    # Each case pairs a list of records with the total expected from
    # Record.sum_records: repeated ints, repeated floats, an increasing
    # run, values that cancel out, and all-negative values.
    cases = [
        (build(1, 1, 1, 1), 4),
        (build(2, 2, 2, 2), 8),
        (build(1.0, 1.0, 1.0, 1.0), 4.0),
        (build(2.0, 2.0, 2.0, 2.0), 8.0),
        (build(1, 2, 3, 4), 10),
        (build(-1, 1), 0),
        (build(-10, -15, -20, -15), -60),
    ]
    for records, expected in cases:
        assert expected == Record.sum_records(records)

    # Inputs that are not lists of Record (a string, a list holding a
    # string, a list of ints, a bare int) are expected to raise TypeError.
    for bad_input in ("dsdzs", ["dzdqzdq"], [1548, 1548], 1254):
        with pytest.raises(TypeError):
            Record.sum_records(bad_input)

    # An empty list sums to zero.
    assert 0 == Record.sum_records([])
from blobrl import Trainer, Record
from blobrl.agents import CategoricalDQN, DQN, DoubleDQN
import gym

# Benchmark settings: at most MAX_ROUNDS training rounds per agent, each
# round asking the trainer for EPISODES_PER_ROUND episodes; training stops
# early once the best summed episode exceeds TARGET_SCORE.
MAX_ROUNDS = 100
EPISODES_PER_ROUND = 50
TARGET_SCORE = 200


def main():
    """Train several DQN variants on CartPole-v1 and print their progress.

    For each agent class a fresh environment is created and the agent is
    trained in rounds. After every round the best value of
    ``Record.sum_records`` over the episodes held by ``trainer.logger`` is
    printed; training for that agent stops as soon as this value exceeds
    ``TARGET_SCORE`` or after ``MAX_ROUNDS`` rounds. A summary line is
    printed once the agent is done.
    """
    for agent_class in [CategoricalDQN, DQN, DoubleDQN]:
        env = gym.make("CartPole-v1")
        try:
            agent = agent_class(env.observation_space, env.action_space)
            # NOTE(review): the previous version built this instance, then
            # discarded it and handed the class to Trainer. Passing the
            # instance assumes Trainer accepts a constructed agent --
            # confirm against blobrl.Trainer.
            trainer = Trainer(environment=env, agent=agent)

            for round_index in range(MAX_ROUNDS):
                trainer.train(
                    max_episode=EPISODES_PER_ROUND,
                    render=False,
                    nb_evaluation=0,
                )
                # Best summed episode among everything the logger holds.
                best = max(
                    Record.sum_records(episode)
                    for episode in trainer.logger.episodes
                )
                print(agent_class.__name__, round_index, best)
                if best > TARGET_SCORE:
                    break

            print("####### ", agent_class.__name__, round_index, best, " #######")
        finally:
            # Release the environment even if training raises.
            env.close()


if __name__ == "__main__":
    main()