def test_trainer(r):
    """Integration test: gather one rollout, then send a training request.

    Starts a ``Gatherer``, a ``Trainer`` and a ``ListenTestServer`` in
    ``ServerThread`` instances, sends a ``RolloutMessage`` built around a
    random policy, asserts that the rollout has collected at least
    ``config.gatherer.num_steps_per_rollout`` entries, and finally sends a
    ``TrainMessage`` with freshly constructed actor and critic.

    Args:
        r: client handed to ``Db(redis_client=r)`` and to every
            ``Message.send`` call (presumably a Redis client fixture --
            confirm against the test configuration).
    """
    config = Discrete('LunarLander-v2')
    db = Db(redis_client=r)
    rollout = db.create_rollout(config.data.coder.construct())
    g = Gatherer()
    t = Trainer()
    tst = ListenTestServer(r)
    # Named sender_id rather than `uuid` so the stdlib module name is not
    # shadowed inside this function.
    sender_id = uuid4()

    try:
        ServerThread(g).start()
        ServerThread(t).start()
        ServerThread(tst).start()

        random_policy = config.random_policy.construct()
        actor = config.actor.construct()
        critic = config.critic.construct()

        RolloutMessage(sender_id, 'LunarLander-v2_xxx', rollout.id, random_policy, config, 1).send(r)
        # Fixed wait for the gatherer thread to fill the rollout.
        sleep(3)
        assert len(rollout) >= config.gatherer.num_steps_per_rollout

        TrainMessage(sender_id, actor, critic, config).send(r)
        # Fixed wait while the trainer thread works on the request.
        sleep(50)
    finally:
        # Always send the exit message, even if deleting the rollout fails;
        # otherwise the server threads started above would never be told
        # to stop.
        try:
            db.delete_rollout(rollout)
        finally:
            ExitMessage(sender_id).send(r)
def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0, redis_password=None):
    """Set up the rollout-handling service.

    Forwards the Redis connection settings to the base class, registers
    ``self.rollout`` as the handler for ``RolloutMessage`` and opens the
    experience buffer.

    Args:
        redis_host: Redis host name.
        redis_port: Redis port.
        redis_db: Redis database index.
        redis_password: Redis password, or ``None``.
    """
    super().__init__(redis_host, redis_port, redis_db, redis_password)
    self.handler.register(RolloutMessage, self.rollout)
    self.rollout_thread = None
    # db=redis_db was previously omitted here, so the experience buffer
    # ignored the redis_db argument that is forwarded to the base class.
    self.exp_buffer = Db(host=redis_host, port=redis_port, db=redis_db, password=redis_password)
    # Reuse the connection owned by the experience buffer.
    self.redis = self.exp_buffer.redis
    self.job = None
    logger.info('Init Complete')
def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0, redis_password=None):
    """Set up the training service.

    Passes the Redis settings to the base class, registers
    ``self.handle_train`` for ``TrainMessage`` and opens a ``Db`` on the
    same host, port, database index and password.

    Args:
        redis_host: Redis host name.
        redis_port: Redis port.
        redis_db: Redis database index.
        redis_password: Redis password, or ``None``.
    """
    super().__init__(redis_host, redis_port, redis_db, redis_password)
    self.handler.register(TrainMessage, self.handle_train)

    connection = {
        'host': redis_host,
        'port': redis_port,
        'db': redis_db,
        'password': redis_password,
    }
    self.db = Db(**connection)

    logger.info('Init Complete')
def rollout_policy(num_episodes, policy, config, capsys=None, render=False, render_freq=1,
                   redis_host='localhost', redis_port=6379):
    """Run ``policy`` for ``num_episodes`` episodes and return the rollout.

    The policy is switched to eval mode and moved to the CPU. A rollout is
    created in ``Db`` database index 1 on the given Redis host and port, and
    each episode is produced by ``single_episode``.

    Args:
        num_episodes: number of episodes to run.
        policy: object providing ``eval()`` and ``to(device)``.
        config: configuration providing ``data.coder``, ``env.name`` and
            ``env.construct()``.
        capsys: not used by this function; kept for interface compatibility.
        render: when true, rendering is requested for selected episodes.
        render_freq: an episode is rendered when its index is a multiple of
            this value (only consulted when ``render`` is true).
        redis_host: Redis host name for the rollout store.
        redis_port: Redis port for the rollout store.

    Returns:
        The rollout, after ``finalize()`` has been called on it.
    """
    policy = policy.eval().to('cpu')
    db = Db(host=redis_host, port=redis_port, db=1)
    rollout = db.create_rollout(config.data.coder.construct())
    # The viewer is constructed but not referenced again in this function;
    # it is kept because constructing it may have side effects.
    v = UniImageViewer(config.env.name, (200, 160))
    env = config.env.construct()
    rewards = []

    for i in range(num_episodes):
        # Check `render` first so render_freq is never used as a divisor
        # when rendering is switched off.
        render_iter = render and (i % render_freq == 0)
        episode = single_episode(env, config, policy, rollout, render=render_iter)
        rewards.append(episode.total_reward())

    rollout.finalize()
    if rewards:
        logger.info(f'ave reward {mean(rewards)}')
    else:
        # No episodes ran, so there is nothing to average; skip mean() on
        # an empty list and still return the finalized rollout.
        logger.info('ave reward n/a (no episodes)')
    return rollout
def test_gatherer_processing(r):
    """Integration test: a ``Gatherer`` fills a rollout on request.

    Starts a ``Gatherer`` and a ``ListenTestServer`` in ``ServerThread``
    instances, sends a ``RolloutMessage`` built around a random policy and
    asserts that the rollout has collected at least
    ``config.gatherer.num_steps_per_rollout`` entries.

    Args:
        r: client handed to ``Db(redis_client=r)`` and to every
            ``Message.send`` call (presumably a Redis client fixture --
            confirm against the test configuration).
    """
    config = Discrete('LunarLander-v2')
    db = Db(redis_client=r)
    # NOTE(review): other callers in this file pass
    # config.data.coder.construct() to create_rollout; confirm whether the
    # un-constructed coder is intended here.
    rollout = db.create_rollout(config.data.coder)
    s = Gatherer()
    tst = ListenTestServer(r)

    try:
        ServerThread(s).start()
        ServerThread(tst).start()

        random_policy = config.random_policy.construct()
        RolloutMessage(s.id, 'LunarLander-v2_xxx', rollout.id, random_policy, config, 1).send(r)
        # Fixed wait for the gatherer thread to fill the rollout.
        sleep(5)
        assert len(rollout) >= config.gatherer.num_steps_per_rollout
    finally:
        # Always send the exit message, even if deleting the rollout fails;
        # otherwise the server threads started above would never be told
        # to stop.
        try:
            db.delete_rollout(rollout)
        finally:
            ExitMessage(s.id).send(r)
def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0, redis_password=None,
             db_host='localhost', db_port=5432, db_name='policy_db', db_user='******',
             db_password=None):
    """Set up the service that owns the policy database and message routing.

    Passes the Redis settings to the base class, opens a ``PolicyDB`` and a
    ``Db`` experience buffer, stores the policy-database settings on the
    instance, registers the five message handlers, resets the config and
    networks, calls ``self.resume`` and starts the heartbeat.

    Args:
        redis_host: Redis host name.
        redis_port: Redis port.
        redis_db: Redis database index.
        redis_password: Redis password, or ``None``.
        db_host: policy database host name.
        db_port: policy database port.
        db_name: policy database name.
        db_user: policy database user.
        db_password: policy database password, or ``None``.
    """
    super().__init__(redis_host, redis_port, redis_db, redis_password)

    self.db = PolicyDB(
        db_host=db_host,
        db_port=db_port,
        db_name=db_name,
        db_user=db_user,
        db_password=db_password,
    )
    self.exp_buffer = Db(
        host=redis_host,
        port=redis_port,
        db=redis_db,
        password=redis_password,
    )

    # Keep the policy-database settings available on the instance.
    self.db_host, self.db_port, self.db_name = db_host, db_port, db_name
    self.db_user, self.db_password = db_user, db_password

    # Message type -> bound handler, registered in this order.
    routes = (
        (StartMessage, self.handle_start),
        (StopMessage, self.handle_stop),
        (EpisodeMessage, self.handle_episode),
        (TrainCompleteMessage, self.handle_train_complete),
        (ConfigUpdateMessage, self.handle_config_update),
    )
    for message_type, callback in routes:
        self.handler.register(message_type, callback)

    self.config = configs.default
    self.actor = None
    self.critic = None
    self.resume(self.db)

    # self.r is not assigned in this method; presumably set by the base
    # class -- confirm.
    self.last_active = SharedTimestamp(self.r)
    self.start_heartbeat(5, self.heartbeat, redis=self.r)
    logger.info('Init Complete')
parser.add_argument("-pa", "--postgres-password", help='password of postgres server',
                    dest='postgres_password', default='password')
args = parser.parse_args()

# Selectable configurations, keyed by the name used in this script.
config_list = {
    'CartPole-v0': configs.Discrete('CartPole-v0'),
    'LunarLander-v2': configs.Discrete('LunarLander-v2'),
    'Acrobot-v1': configs.Discrete('Acrobot-v1'),
    'MountainCar-v0': configs.Discrete('MountainCar-v0'),
    'HalfCheetah-v1': configs.Continuous('RoboschoolHalfCheetah-v1'),
    'Hopper-v0': configs.Continuous('RoboschoolHopper-v1'),
}

# Direct Redis connection plus the experience buffer, both built from the
# same host, port and password arguments.
r = Redis(host=args.redis_host, port=args.redis_port, password=args.redis_password)
exp_buffer = Db(host=args.redis_host, port=args.redis_port, password=args.redis_password)

# Keyword arguments are used so this call does not depend on the order of
# PolicyDB's parameters. Assumes PolicyDB accepts the db_host / db_port /
# db_name / db_user / db_password keywords -- confirm.
policy_db = PolicyDB(
    db_host=args.postgres_host,
    db_password=args.postgres_password,
    db_user=args.postgres_user,
    db_name=args.postgres_db,
    db_port=args.postgres_port,
)

gui_uuid = uuid.uuid4()
services = MicroServiceBuffer()
PingMessage(gui_uuid).send(r)

# State dictionaries and the slot counter, all starting empty.
gatherers = {}
gatherers_progress = {}
gatherers_progress_epi = {}
next_free_slot = 0

# Route EpisodeMessage to `episode` via a handler created with the
# 'rollout' argument.
handler = MessageHandler(r, 'rollout')
handler.register(EpisodeMessage, episode)