Exemplo n.º 1
0
def test_trainer(r):
    """Smoke-test a gather-then-train round trip through the message servers.

    A fresh rollout is created, a Gatherer, a Trainer and a ListenTestServer
    are started on background threads, a rollout request is sent with a random
    policy, and once the rollout holds enough steps a train request is sent.

    Args:
        r: client handed to ``Db`` as ``redis_client`` and used to send every
            message (presumably a Redis connection fixture -- confirm against
            the test configuration).
    """
    config = Discrete('LunarLander-v2')
    db = Db(redis_client=r)
    rollout = db.create_rollout(config.data.coder.construct())
    gatherer = Gatherer()
    trainer = Trainer()
    listener = ListenTestServer(r)
    sender_id = uuid4()

    try:
        # each server runs on its own thread
        for server in (gatherer, trainer, listener):
            ServerThread(server).start()

        random_policy = config.random_policy.construct()
        actor = config.actor.construct()
        critic = config.critic.construct()

        rollout_request = RolloutMessage(sender_id, 'LunarLander-v2_xxx', rollout.id,
                                         random_policy, config, 1)
        rollout_request.send(r)

        # give the gatherer time to fill the rollout before checking it
        sleep(3)
        assert len(rollout) >= config.gatherer.num_steps_per_rollout

        train_request = TrainMessage(sender_id, actor, critic, config)
        train_request.send(r)

        # give the trainer time to work through the rollout
        sleep(50)

    finally:
        # always remove the rollout and ask the servers to exit
        db.delete_rollout(rollout)
        ExitMessage(sender_id).send(r)
Exemplo n.º 2
0
    def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0, redis_password=None):
        """Connect to Redis and register the rollout handler.

        Args:
            redis_host: host name of the Redis server.
            redis_port: port of the Redis server.
            redis_db: index of the Redis database to use.
            redis_password: password for the Redis server, or None.
        """
        super().__init__(redis_host, redis_port, redis_db, redis_password)
        self.handler.register(RolloutMessage, self.rollout)
        self.rollout_thread = None
        # forward redis_db so the experience buffer lives in the same database
        # as the message connection opened by the base class
        self.exp_buffer = Db(host=redis_host, port=redis_port, db=redis_db, password=redis_password)
        self.redis = self.exp_buffer.redis
        self.job = None

        logger.info('Init Complete')
Exemplo n.º 3
0
    def __init__(self,
                 redis_host='localhost',
                 redis_port=6379,
                 redis_db=0,
                 redis_password=None):
        """Connect to Redis, register the train handler and open the rollout store.

        Args:
            redis_host: host name of the Redis server.
            redis_port: port of the Redis server.
            redis_db: index of the Redis database to use.
            redis_password: password for the Redis server, or None.
        """
        super().__init__(redis_host, redis_port, redis_db, redis_password)
        self.handler.register(TrainMessage, self.handle_train)

        # the rollout store uses the same connection settings as the server
        store_settings = dict(host=redis_host,
                              port=redis_port,
                              db=redis_db,
                              password=redis_password)
        self.db = Db(**store_settings)

        logger.info('Init Complete')
Exemplo n.º 4
0
class Trainer(Server):
    """Server that trains an actor/critic pair when a TrainMessage arrives."""

    def __init__(self,
                 redis_host='localhost',
                 redis_port=6379,
                 redis_db=0,
                 redis_password=None):
        """Connect to Redis, register the train handler and open the rollout store.

        Args:
            redis_host: host name of the Redis server.
            redis_port: port of the Redis server.
            redis_db: index of the Redis database to use.
            redis_password: password for the Redis server, or None.
        """
        super().__init__(redis_host, redis_port, redis_db, redis_password)
        self.handler.register(TrainMessage, self.handle_train)
        self.db = Db(host=redis_host,
                     port=redis_port,
                     db=redis_db,
                     password=redis_password)

        logger.info('Init Complete')

    def handle_train(self, msg):
        """Train on the latest rollout and publish the resulting models.

        Args:
            msg: message carrying ``actor``, ``critic`` and ``config``.

        Raises:
            AssertionError: if the latest rollout contains no data.
        """
        trainer = msg.config.algo.construct()

        exp_buffer = self.db.latest_rollout(msg.config.data.coder.construct())
        assert len(exp_buffer) != 0, 'latest rollout is empty, nothing to train on'

        logger.info('started training')
        actor, critic = trainer(msg.actor, msg.critic, exp_buffer, msg.config)

        # use the module logger (not the root ``logging`` functions) so this
        # line is routed the same way as every other message in this class
        logger.info('training complete')
        TrainCompleteMessage(self.id, actor, critic, msg.config).send(self.r)
Exemplo n.º 5
0
class Gatherer(Server):
    """Server that runs episodes with a received policy and stores them in a rollout."""

    def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0, redis_password=None):
        """Connect to Redis and register the rollout handler.

        Args:
            redis_host: host name of the Redis server.
            redis_port: port of the Redis server.
            redis_db: index of the Redis database to use.
            redis_password: password for the Redis server, or None.
        """
        super().__init__(redis_host, redis_port, redis_db, redis_password)
        self.handler.register(RolloutMessage, self.rollout)
        self.rollout_thread = None
        # forward redis_db so the experience buffer lives in the same database
        # as the message connection opened by the base class
        self.exp_buffer = Db(host=redis_host, port=redis_port, db=redis_db, password=redis_password)
        self.redis = self.exp_buffer.redis
        self.job = None

        logger.info('Init Complete')

    def rollout(self, msg):
        """Run episodes for the requested rollout until it is full or superseded.

        Episodes are gathered while the requested rollout is still the current
        one and it holds fewer than ``config.gatherer.num_steps_per_rollout``
        steps. An EpisodeMessage is sent after every episode.

        Args:
            msg: message carrying ``rollout_id``, ``policy``, ``config`` and ``run``.
        """
        logger.debug(f'gathering for rollout: {msg.rollout_id}')
        policy = msg.policy.to('cpu').eval()
        env = msg.config.env.construct()
        rollout = self.exp_buffer.rollout(msg.rollout_id, msg.config.data.coder.construct())
        episode_number = 0

        while self.exp_buffer.rollout_seq.current() == msg.rollout_id and len(
                rollout) < msg.config.gatherer.num_steps_per_rollout:
            logger.info(f'starting episode {episode_number} of {msg.config.env.name}')
            episode = single_episode(env, msg.config, policy, rollout)

            epi_mess = EpisodeMessage(self.id, msg.run, episode_number, len(episode), episode.total_reward(),
                                      msg.config.gatherer.num_steps_per_rollout)
            epi_mess.monitor['entropy'] = episode.entropy
            epi_mess.send(self.r)

            episode_number += 1
Exemplo n.º 6
0
def rollout_policy(num_episodes, policy, config, capsys=None, render=False, render_freq=1, redis_host='localhost',
                   redis_port=6379, redis_db=1):
    """Run a policy for a number of episodes and return the finalized rollout.

    The policy is switched to eval mode and moved to the CPU before use.

    Args:
        num_episodes: number of episodes to run.
        policy: policy object supporting ``eval()`` and ``to('cpu')``.
        config: run configuration supplying the coder and the environment.
        capsys: unused; kept so existing callers keep working.
        render: when true, every ``render_freq``-th episode is rendered.
        render_freq: render interval in episodes; must be non-zero.
        redis_host: host name of the Redis server.
        redis_port: port of the Redis server.
        redis_db: index of the Redis database holding the rollout; defaults to
            1, the value that was previously hard-coded.

    Returns:
        The rollout the episodes were written to, after ``finalize()``.
    """
    policy = policy.eval()
    policy = policy.to('cpu')
    db = Db(host=redis_host, port=redis_port, db=redis_db)
    rollout = db.create_rollout(config.data.coder.construct())
    # NOTE(review): the viewer is never used below; kept in case constructing it
    # has side effects the rendering path relies on -- confirm and remove
    v = UniImageViewer(config.env.name, (200, 160))
    env = config.env.construct()
    rewards = []

    for i in range(num_episodes):
        render_iter = (i % render_freq == 0) and render
        episode = single_episode(env, config, policy, rollout, render=render_iter)
        rewards.append(episode.total_reward())

    rollout.finalize()
    if rewards:
        logger.info(f'ave reward {mean(rewards)}')
    else:
        # no episodes ran, so there is nothing to average
        logger.info('ave reward n/a (no episodes were run)')
    return rollout
Exemplo n.º 7
0
def test_gatherer_processing(r):
    """Check that a Gatherer fills a rollout after receiving a RolloutMessage.

    Args:
        r: client handed to ``Db`` as ``redis_client`` and used to send every
            message (presumably a Redis connection fixture -- confirm against
            the test configuration).
    """
    config = Discrete('LunarLander-v2')
    db = Db(redis_client=r)
    # pass a constructed coder, as every other create_rollout call site does
    rollout = db.create_rollout(config.data.coder.construct())
    s = Gatherer()
    tst = ListenTestServer(r)

    try:
        ServerThread(s).start()
        ServerThread(tst).start()

        random_policy = config.random_policy.construct()

        RolloutMessage(s.id, 'LunarLander-v2_xxx', rollout.id, random_policy, config, 1).send(r)
        # give the gatherer time to fill the rollout before checking it
        sleep(5)
        assert len(rollout) >= config.gatherer.num_steps_per_rollout

    finally:
        # always remove the rollout and ask the servers to exit
        db.delete_rollout(rollout)
        ExitMessage(s.id).send(r)
Exemplo n.º 8
0
    def __init__(self,
                 redis_host='localhost',
                 redis_port=6379,
                 redis_db=0,
                 redis_password=None,
                 db_host='localhost',
                 db_port=5432,
                 db_name='policy_db',
                 db_user='******',
                 db_password=None):
        """Connect to Redis and the policy database, then start coordinating.

        Registers the message handlers, resumes from the latest stored run and
        starts the heartbeat.

        Args:
            redis_host: host name of the Redis server.
            redis_port: port of the Redis server.
            redis_db: index of the Redis database to use.
            redis_password: password for the Redis server, or None.
            db_host: host name of the policy database.
            db_port: port of the policy database.
            db_name: name of the policy database.
            db_user: user name for the policy database.
            db_password: password for the policy database, or None.
        """
        super().__init__(redis_host, redis_port, redis_db, redis_password)

        policy_db_settings = dict(db_host=db_host,
                                  db_port=db_port,
                                  db_name=db_name,
                                  db_user=db_user,
                                  db_password=db_password)
        self.db = PolicyDB(**policy_db_settings)
        self.exp_buffer = Db(host=redis_host,
                             port=redis_port,
                             db=redis_db,
                             password=redis_password)

        # keep the policy database settings on the instance (self.db_host,
        # self.db_port, ...) for later use
        for attribute, value in policy_db_settings.items():
            setattr(self, attribute, value)

        handlers = (
            (StartMessage, self.handle_start),
            (StopMessage, self.handle_stop),
            (EpisodeMessage, self.handle_episode),
            (TrainCompleteMessage, self.handle_train_complete),
            (ConfigUpdateMessage, self.handle_config_update),
        )
        for message_type, callback in handlers:
            self.handler.register(message_type, callback)

        self.config = configs.default
        self.actor = None
        self.critic = None
        self.resume(self.db)
        self.last_active = SharedTimestamp(self.r)
        self.start_heartbeat(5, self.heartbeat, redis=self.r)

        logger.info('Init Complete')
Exemplo n.º 9
0
class Coordinator(Server):
    """Server that drives the gather/train cycle of a run.

    The run state is kept in Redis under the key ``co-ordinator-state``.
    Policies are persisted through ``PolicyDB`` and rollouts through the
    Redis-backed ``Db`` experience buffer.
    """

    def __init__(self,
                 redis_host='localhost',
                 redis_port=6379,
                 redis_db=0,
                 redis_password=None,
                 db_host='localhost',
                 db_port=5432,
                 db_name='policy_db',
                 db_user='******',
                 db_password=None):
        """Connect to Redis and the policy database, then start coordinating.

        Registers the message handlers, resumes from the latest stored run and
        starts the heartbeat.

        Args:
            redis_host: host name of the Redis server.
            redis_port: port of the Redis server.
            redis_db: index of the Redis database to use.
            redis_password: password for the Redis server, or None.
            db_host: host name of the policy database.
            db_port: port of the policy database.
            db_name: name of the policy database.
            db_user: user name for the policy database.
            db_password: password for the policy database, or None.
        """
        super().__init__(redis_host, redis_port, redis_db, redis_password)
        self.db = PolicyDB(db_host=db_host,
                           db_port=db_port,
                           db_name=db_name,
                           db_user=db_user,
                           db_password=db_password)
        self.exp_buffer = Db(host=redis_host,
                             port=redis_port,
                             db=redis_db,
                             password=redis_password)

        # kept so heartbeat() can open its own PolicyDB connection
        self.db_host = db_host
        self.db_port = db_port
        self.db_name = db_name
        self.db_user = db_user
        self.db_password = db_password

        self.handler.register(StartMessage, self.handle_start)
        self.handler.register(StopMessage, self.handle_stop)
        self.handler.register(EpisodeMessage, self.handle_episode)
        self.handler.register(TrainCompleteMessage, self.handle_train_complete)
        self.handler.register(ConfigUpdateMessage, self.handle_config_update)
        self.config = configs.default
        self.actor = None
        self.critic = None
        self.resume(self.db)
        self.last_active = SharedTimestamp(self.r)
        self.start_heartbeat(5, self.heartbeat, redis=self.r)

        logger.info('Init Complete')

    def set_state(self, state):
        """Store the run state string in Redis."""
        self.r.set('co-ordinator-state', state.encode())

    @property
    def state(self):
        """Run state string read back from Redis.

        NOTE(review): this fails on ``.decode()`` if no state has been stored
        yet, assuming the client returns None for a missing key -- confirm.
        """
        return self.r.get('co-ordinator-state').decode()

    def _new_rollout(self):
        """Discard all existing rollouts and create a fresh one for the current config."""
        self.exp_buffer.clear_rollouts()
        return self.exp_buffer.create_rollout(
            self.config.data.coder.construct())

    def _request_rollout(self, rollout):
        """Send a RolloutMessage for ``rollout`` using the current actor and config."""
        RolloutMessage(self.id, self.config.run_id, rollout.id, self.actor,
                       self.config,
                       self.config.gatherer.episodes_per_gatherer).send(self.r)

    def resume(self, db):
        """Restore state, config and models from the latest run, if there is one.

        Gathering is restarted unless the restored state is STOPPED.

        Args:
            db: policy database to read the latest run from.
        """
        record = db.latest_run()
        if record is not None:
            logger.info(f'resuming run {record.run} state {record.run_state}')
            self.set_state(record.run_state)
            self.config = record.config_b

            rollout = self._new_rollout()
            self.actor = self.config.actor.construct()
            self.actor.load_state_dict(record.actor)
            self.critic = self.config.critic.construct()
            self.critic.load_state_dict(record.critic)

            if self.state != STOPPED:
                self.set_state(GATHERING)
                self._request_rollout(rollout)

    def handle_start(self, msg):
        """Begin a new run with freshly constructed models and start gathering.

        Args:
            msg: message carrying the ``config`` of the new run.
        """
        logger.info(f'started run {msg.config.run_id}')
        self.set_state(GATHERING)
        self.config = msg.config
        self.actor = self.config.actor.construct()
        self.critic = self.config.critic.construct()

        # setup monitoring for the new run
        self.db.write_policy(self.config.run_id, self.state,
                             self.actor.state_dict(), self.critic.state_dict(),
                             {'ave_reward_episode': 0.0}, self.config)
        StartMonitoringMessage(self.id, msg.config.run_id).send(self.r)

        # init the experience buffer and start the actors
        self._request_rollout(self._new_rollout())

    def handle_episode(self, msg):
        """Start training once the latest rollout holds enough steps.

        Ignored while STOPPED. When the rollout is full and training has not
        already started, the rollout is finalized, the current policy is saved
        with its average episode reward, and a TrainMessage is sent.

        Args:
            msg: episode message; its content is not read here.
        """
        if not self.state == STOPPED:
            rollout = self.exp_buffer.latest_rollout(
                self.config.data.coder.construct())
            steps = len(rollout)

            if steps >= self.config.gatherer.num_steps_per_rollout and not self.state == TRAINING:
                # experience buffer is full, start training
                logger.info(f'Starting training with {steps} steps')
                rollout.finalize()
                self.set_state(TRAINING)

                # capture statistics from the exp buffer
                total_reward = sum(episode.total_reward() for episode in rollout)
                ave_reward = total_reward / rollout.num_episodes()
                stats = {'ave_reward_episode': ave_reward}

                # and save the policy
                self.db.write_policy(self.config.run_id, self.state,
                                     self.actor.state_dict(),
                                     self.critic.state_dict(), stats,
                                     self.config)
                self.db.update_reservoir(
                    self.config.run_id,
                    self.config.gatherer.policy_reservoir_depth)
                self.db.update_best(self.config.run_id,
                                    self.config.gatherer.policy_top_depth)
                self.db.prune(self.config.run_id)

                # this is the request to train, not the completion notice
                logger.info('Sending Train message')
                TrainMessage(self.id, self.actor, self.critic,
                             self.config).send(self.r)

        self.last_active.update()

    def handle_train_complete(self, msg):
        """Adopt the trained models and start the next gathering round.

        Args:
            msg: message carrying the trained ``actor`` and ``critic``.
        """
        logger.info('Training completed')

        self.actor = msg.actor
        self.critic = msg.critic
        ResetMessage(self.id).send(self.r)

        rollout = self._new_rollout()

        if not self.state == STOPPED:
            self._request_rollout(rollout)
            self.set_state(GATHERING)

        self.last_active.update()

    def handle_stop(self, msg):
        """Mark the run as STOPPED in Redis and in the policy database."""
        logger.debug('Got STOP message')
        self.set_state(STOPPED)
        self.db.set_state_latest(STOPPED)

    def handle_config_update(self, msg):
        """Switch to a new config and restart gathering from the best stored policy.

        Args:
            msg: message carrying the new ``config``.
        """
        logger.info('Got config update')

        run = self.db.latest_run()
        record = self.db.best(run.run).get()
        self.config = msg.config

        rollout = self._new_rollout()

        self.actor = self.config.actor.construct()
        self.actor.load_state_dict(record.actor)

        self.critic = self.config.critic.construct()
        self.critic.load_state_dict(record.critic)

        self.set_state(GATHERING)

        self._request_rollout(rollout)

    def heartbeat(self, redis):
        """Resume the run if it has been inactive for longer than the timeout.

        Only acts while the stored state is GATHERING or TRAINING.

        Args:
            redis: Redis client used to read the state and the activity timestamp.
        """
        last_active = SharedTimestamp(redis)
        # no state may have been stored yet; treat that as "nothing to resume"
        raw_state = redis.get('co-ordinator-state')
        state = raw_state.decode() if raw_state is not None else None
        time_inactive = time.time() - last_active.ts
        logger.debug(
            f'Heartbeat state : {state},  time_inactive: {time_inactive}, timeout: {self.config.timeout}'
        )
        if state == GATHERING or state == TRAINING:
            if time_inactive > self.config.timeout:

                logger.info(
                    f'Heartbeat: inactive for {time_inactive} in {state} state.  Attempting resume'
                )
                # database connection must be formed inside the thread; it is
                # only opened when a resume is actually attempted
                db = PolicyDB(db_host=self.db_host,
                              db_port=self.db_port,
                              db_name=self.db_name,
                              db_user=self.db_user,
                              db_password=self.db_password)
                self.resume(db)
Exemplo n.º 10
0
    parser.add_argument("-pa", "--postgres-password", help='password of postgres server', dest='postgres_password',
                        default='password')

    args = parser.parse_args()

    config_list = {
        'CartPole-v0': configs.Discrete('CartPole-v0'),
        'LunarLander-v2': configs.Discrete('LunarLander-v2'),
        'Acrobot-v1': configs.Discrete('Acrobot-v1'),
        'MountainCar-v0': configs.Discrete('MountainCar-v0'),
        'HalfCheetah-v1': configs.Continuous('RoboschoolHalfCheetah-v1'),
        'Hopper-v0': configs.Continuous('RoboschoolHopper-v1')
    }

    r = Redis(host=args.redis_host, port=args.redis_port, password=args.redis_password)
    exp_buffer = Db(host=args.redis_host, port=args.redis_port, password=args.redis_password)
    policy_db = PolicyDB(args.postgres_host, args.postgres_password, args.postgres_user, args.postgres_db,
                         args.postgres_port)

    gui_uuid = uuid.uuid4()

    services = MicroServiceBuffer()
    PingMessage(gui_uuid).send(r)

    gatherers = {}
    gatherers_progress = {}
    gatherers_progress_epi = {}
    next_free_slot = 0

    handler = MessageHandler(r, 'rollout')
    handler.register(EpisodeMessage, episode)