Example #1
    def evaluate_agent(self, itr):
        """Signal worker processes to perform agent evaluation.  If a max
        number of evaluation trajectories was specified, keep watch over the
        number of trajectories finished and signal an early end if the limit
        is reached.  Return a list of trajectory-info objects from the
        completed episodes.
        """
        self.ctrl.itr.value = itr
        self.ctrl.do_eval.value = True
        self.sync.stop_eval.value = False
        self.ctrl.barrier_in.wait()
        traj_infos = list()
        if self.eval_max_trajectories is not None:
            while True:
                time.sleep(EVAL_TRAJ_CHECK)
                traj_infos.extend(
                    drain_queue(self.eval_traj_infos_queue,
                                guard_sentinel=True))
                if len(traj_infos) >= self.eval_max_trajectories:
                    self.sync.stop_eval.value = True
                    logger.log("Evaluation reached max num trajectories "
                               f"({self.eval_max_trajectories}).")
                    break  # Stop possibly before workers reach max_T.
                if self.ctrl.barrier_out.parties - self.ctrl.barrier_out.n_waiting == 1:
                    logger.log("Evaluation reached max num time steps "
                               f"({self.eval_max_T}).")
                    break  # Workers reached max_T.
        self.ctrl.barrier_out.wait()
        traj_infos.extend(
            drain_queue(self.eval_traj_infos_queue, n_sentinel=self.n_worker))
        self.ctrl.do_eval.value = False

        return traj_infos
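
The evaluate_agent methods above lean on drain_queue to pull trajectory infos out of a multiprocessing queue, with two sentinel modes: guard_sentinel for non-blocking drains while workers are still running, and n_sentinel for a final blocking drain that waits for one None per worker. The following is a minimal sketch of such a helper, written as an illustration of those semantics rather than as rlpyt's exact implementation:

import queue


def drain_queue(queue_obj, n_sentinel=0, guard_sentinel=False):
    """Empty a multiprocessing queue and return its contents as a list.

    n_sentinel > 0: block until that many None sentinels arrive (one per
    worker), guaranteeing every submission has been collected.
    guard_sentinel=True: drain without blocking, but stop early and restore
    a None if one is seen, leaving it for a later blocking drain.
    """
    contents = list()
    if n_sentinel > 0:  # Blocking drain: wait for one sentinel per worker.
        for _ in range(n_sentinel):
            obj = queue_obj.get()
            while obj is not None:
                contents.append(obj)
                obj = queue_obj.get()
    else:  # Non-blocking drain.
        while True:
            try:
                obj = queue_obj.get(block=False)
            except queue.Empty:
                break
            if guard_sentinel and obj is None:
                queue_obj.put(None)  # Put the sentinel back for the later blocking drain.
                break
            if obj is not None:
                contents.append(obj)
    return contents
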
Example #2
    def evaluate_agent(self, itr):
        """
        评估模型。

        :param itr: 第几次迭代。
        :return: trajectory的统计信息。
        """
        self.ctrl.itr.value = itr
        self.ctrl.do_eval.value = True
        self.sync.stop_eval.value = False
        self.ctrl.barrier_in.wait()
        traj_infos = list()
        if self.eval_max_trajectories is not None:
            while True:
                time.sleep(EVAL_TRAJ_CHECK)
                traj_infos.extend(
                    drain_queue(self.eval_traj_infos_queue,
                                guard_sentinel=True))
                if len(traj_infos) >= self.eval_max_trajectories:
                    self.sync.stop_eval.value = True
                    logger.log("Evaluation reached max num trajectories "
                               f"({self.eval_max_trajectories}).")
                    break  # Stop possibly before workers reach max_T.
                if self.ctrl.barrier_out.parties - self.ctrl.barrier_out.n_waiting == 1:
                    logger.log("Evaluation reached max num time steps "
                               f"({self.eval_max_T}).")
                    break  # Workers reached max_T.
        self.ctrl.barrier_out.wait()
        traj_infos.extend(
            drain_queue(self.eval_traj_infos_queue, n_sentinel=self.n_worker))
        self.ctrl.do_eval.value = False
        return traj_infos
Example #3
    def train(self):
        """
        Run the optimizer in a loop.  Check whether enough new samples have
        been generated, and throttle down if necessary at each iteration.  Log
        at an interval in the number of sampler iterations, not optimizer
        iterations.
        """
        throttle_itr, delta_throttle_itr = self.startup()
        throttle_time = 0.
        sampler_itr = itr = 0
        if self._eval:
            while self.ctrl.sampler_itr.value < 1:  # Sampler does eval first.
                time.sleep(THROTTLE_WAIT)
            traj_infos = drain_queue(self.traj_infos_queue, n_sentinel=1)
            self.store_diagnostics(0, 0, traj_infos, ())
            self.log_diagnostics(0, 0, 0)

        log_counter = 0
        while True:  # Run until sampler hits n_steps and sets ctrl.quit=True.
            logger.set_iteration(itr)
            with logger.prefix(f"opt_itr #{itr} "):
                while self.ctrl.sampler_itr.value < throttle_itr:
                    if self.ctrl.quit.value:
                        break
                    time.sleep(THROTTLE_WAIT)
                    throttle_time += THROTTLE_WAIT
                if self.ctrl.quit.value:
                    break
                if self.ctrl.opt_throttle is not None:
                    self.ctrl.opt_throttle.wait()
                throttle_itr += delta_throttle_itr
                opt_info = self.algo.optimize_agent(
                    itr, sampler_itr=self.ctrl.sampler_itr.value)
                self.agent.send_shared_memory()  # To sampler.
                sampler_itr = self.ctrl.sampler_itr.value
                traj_infos = (list() if self._eval else drain_queue(
                    self.traj_infos_queue))
                self.store_diagnostics(itr, sampler_itr, traj_infos, opt_info)
                if (sampler_itr // self.log_interval_itrs > log_counter):
                    if self._eval:
                        with self.ctrl.sampler_itr.get_lock():
                            traj_infos = drain_queue(self.traj_infos_queue,
                                                     n_sentinel=1)
                        self.store_diagnostics(itr, sampler_itr, traj_infos,
                                               ())
                    self.log_diagnostics(itr, sampler_itr, throttle_time)
                    log_counter += 1
                    throttle_time = 0.
            itr += 1
        # Final log:
        sampler_itr = self.ctrl.sampler_itr.value
        traj_infos = drain_queue(self.traj_infos_queue)
        if traj_infos or not self._eval:
            self.store_diagnostics(itr, sampler_itr, traj_infos, ())
            self.log_diagnostics(itr, sampler_itr, throttle_time)
        self.shutdown()
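
The train loop above coordinates with the sampler purely through shared objects on self.ctrl: counters such as sampler_itr and quit, and barriers for synchronization. As a hedged illustration of the kind of structure assumed (build_ctrl and this exact field layout are hypothetical, not rlpyt's actual constructor), such a control namespace could be assembled from multiprocessing primitives like this:

import multiprocessing as mp
from collections import namedtuple

# Hypothetical sketch of the shared control namespace used above.
AsyncCtrl = namedtuple("AsyncCtrl", ["quit", "itr", "sampler_itr", "do_eval",
    "barrier_in", "barrier_out", "opt_throttle"])


def build_ctrl(n_sampler_workers, n_optimizers):
    return AsyncCtrl(
        quit=mp.Value('b', False),        # Sampler sets True once n_steps are collected.
        itr=mp.Value('l', 0),             # Optimizer iteration, broadcast to workers.
        sampler_itr=mp.Value('l', 0),     # Sampler iteration, polled by the optimizer.
        do_eval=mp.Value('b', False),     # Switch workers between sampling and evaluation.
        barrier_in=mp.Barrier(n_sampler_workers + 1),   # Master and workers enter a batch together...
        barrier_out=mp.Barrier(n_sampler_workers + 1),  # ...and leave it together.
        opt_throttle=(mp.Barrier(n_optimizers)          # Keeps multiple optimizers in lock-step.
                      if n_optimizers > 1 else None),
    )
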
Example #4
 def obtain_samples(self, itr):
     """Signal worker processes to collect samples, and wait until they
     finish. Workers will write directly to the pre-allocated samples
     buffer, which this method returns.  Trajectory-info objects from
     completed trajectories are retrieved from workers through a parallel
     queue object and are also returned.
     """
     self.ctrl.itr.value = itr
     self.ctrl.barrier_in.wait()
     # Workers step environments and sample actions here.
     self.ctrl.barrier_out.wait()
     player_traj_infos = drain_queue(self.player_traj_infos_queue)
     observer_traj_infos = drain_queue(self.observer_traj_infos_queue)
     return self.player_samples_pyt, player_traj_infos, self.observer_samples_pyt, observer_traj_infos
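
The barrier pattern in obtain_samples is only half of the protocol; each worker process runs a matching loop. A rough worker-side sketch under the same ctrl object (collector.collect_batch and the loop shape are hypothetical, for illustration only):

def sampling_worker_loop(ctrl, collector, traj_infos_queue):
    # Hypothetical worker loop mirroring the master's barrier_in / barrier_out calls.
    while True:
        ctrl.barrier_in.wait()           # Master has set ctrl.itr; batch collection begins.
        if ctrl.quit.value:
            break
        # Step environments, writing samples directly into the shared buffer.
        completed_traj_infos = collector.collect_batch(ctrl.itr.value)
        for traj_info in completed_traj_infos:
            traj_infos_queue.put(traj_info)   # Master drains these after barrier_out.
        ctrl.barrier_out.wait()          # Tell the master this batch is finished.
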
Example #5
    def serve_actions_evaluation(self, itr):
        obs_ready, act_ready = self.sync.obs_ready, self.sync.act_ready
        obs_ready_pair = self.obs_ready_pair
        act_ready_pair = self.act_ready_pair
        step_np_pair = self.eval_step_buffer_np_pair
        agent_inputs_pair = self.eval_agent_inputs_pair
        traj_infos = list()
        self.agent.reset()
        stop = False

        for t in range(self.eval_max_T):
            if t % EVAL_TRAJ_CHECK == 0:  # (While workers stepping.)
                traj_infos.extend(
                    drain_queue(self.eval_traj_infos_queue, guard_sentinel=True)
                )
            for alt in range(2):
                step_h = step_np_pair[alt]
                for b in obs_ready_pair[alt]:
                    b.acquire()
                    # assert not b.acquire(block=False)  # Debug check.
                for b_reset in np.where(step_h.done)[0]:
                    step_h.action[b_reset] = 0  # Null prev_action.
                    step_h.reward[b_reset] = 0  # Null prev_reward.
                    self.agent.reset_one(idx=b_reset)
                action, agent_info = self.agent.step(*agent_inputs_pair[alt])
                step_h.action[:] = action
                step_h.agent_info[:] = agent_info
                if (
                    self.eval_max_trajectories is not None
                    and t % EVAL_TRAJ_CHECK == 0
                    and alt == 0
                ):
                    if len(traj_infos) >= self.eval_max_trajectories:
                        for b in obs_ready_pair[1 - alt]:
                            b.acquire()  # Now all workers waiting.
                        self.sync.stop_eval.value = stop = True
                        for w in act_ready:  # Release all workers so they see the stop flag.
                            w.release()
                        break
                for w in act_ready_pair[alt]:
                    # assert not w.acquire(block=False)  # Debug check.
                    w.release()
            if stop:
                logger.log(
                    "Evaluation reached max num trajectories "
                    f"({self.eval_max_trajectories})."
                )
                break

        # TODO: check exit logic for/while ..?
        if not stop:
            logger.log("Evaluation reached max num time steps " f"({self.eval_max_T}).")

        for b in obs_ready:
            b.acquire()  # Workers always do extra release; drain it.
            assert not b.acquire(block=False)  # Debug check.
        for w in act_ready:
            assert not w.acquire(block=False)  # Debug check.

        return traj_infos
Example #6
File: base.py  Project: wilson1yan/rlpyt
 def obtain_samples(self, itr):
     self.ctrl.itr.value = itr
     self.ctrl.barrier_in.wait()
     # Workers step environments and sample actions here.
     self.ctrl.barrier_out.wait()
     traj_infos = drain_queue(self.traj_infos_queue)
     return self.samples_pyt, traj_infos
 def obtain_samples(self, itr):
     # self.samples_np[:] = 0  # Reset all batch sample values (optional).
     self.agent.sample_mode(itr)
     self.ctrl.barrier_in.wait()
     self.serve_actions(itr)  # Worker step environments here.
     self.ctrl.barrier_out.wait()
     traj_infos = drain_queue(self.traj_infos_queue)
     return self.samples_pyt, traj_infos
Example #8
 def train(self):
     throttle_itr, delta_throttle_itr = self.startup()
     throttle_time = 0.
     sampler_itr = itr = 0
     if self._eval:
         while self.ctrl.sampler_itr.value < 1:  # Sampler does eval first.
             time.sleep(THROTTLE_WAIT)
         traj_infos = drain_queue(self.traj_infos_queue, n_sentinel=1)
         self.store_diagnostics(0, 0, traj_infos, ())
         self.log_diagnostics(0, 0, 0)
     log_counter = 0
     while True:  # Run until sampler hits n_steps and sets ctrl.quit=True.
         with logger.prefix(f"opt_itr #{itr} "):
             while self.ctrl.sampler_itr.value < throttle_itr:
                 if self.ctrl.quit.value:
                     break
                 time.sleep(THROTTLE_WAIT)
                 throttle_time += THROTTLE_WAIT
             if self.ctrl.quit.value:
                 break
             if self.ctrl.opt_throttle is not None:
                 self.ctrl.opt_throttle.wait()
             throttle_itr += delta_throttle_itr
             opt_info = self.algo.optimize_agent(itr,
                 sampler_itr=self.ctrl.sampler_itr.value)
             self.agent.send_shared_memory()  # To sampler.
             sampler_itr = self.ctrl.sampler_itr.value
             traj_infos = (list() if self._eval else
                 drain_queue(self.traj_infos_queue))
             self.store_diagnostics(itr, sampler_itr, traj_infos, opt_info)
             if (sampler_itr // self.log_interval_itrs > log_counter):
                 if self._eval:
                     with self.ctrl.sampler_itr.get_lock():
                         traj_infos = drain_queue(self.traj_infos_queue, n_sentinel=1)
                     self.store_diagnostics(itr, sampler_itr, traj_infos, ())
                 self.log_diagnostics(itr, sampler_itr, throttle_time)
                 log_counter += 1
                 throttle_time = 0.
         itr += 1
     # Final log:
     sampler_itr = self.ctrl.sampler_itr.value
     traj_infos = drain_queue(self.traj_infos_queue)
     if traj_infos or not self._eval:
         self.store_diagnostics(itr, sampler_itr, traj_infos, ())
         self.log_diagnostics(itr, sampler_itr, throttle_time)
     self.shutdown()
 def obtain_samples(self, itr, db_idx):
     self.ctrl.itr.value = itr
     self.sync.db_idx.value = db_idx  # Double buffer index.
     self.ctrl.barrier_in.wait()
     # Workers step environments and sample actions here.
     self.ctrl.barrier_out.wait()
     traj_infos = drain_queue(self.traj_infos_queue)
     return traj_infos
Example #10
 def init_obs_norm(self):
     """
     Initializes observation normalization parameters in the intrinsic bonus model.
     The agent's base network is not stepped; rather, the action space is sampled
     randomly to exercise the bonus model's obs-norm module.  This will run for at
     least as many steps as specified in self.obs_norm_steps.
     """
     logger.log(
         f"Sampler initializing bonus model observation normalization, steps: {self.obs_norm_steps}"
     )
     action_space = self.EnvCls(**self.env_kwargs).action_space
     world_batch_size = self.batch_size * self.world_size
     from math import ceil
     for _ in range(ceil(self.obs_norm_steps / world_batch_size)):
         self.ctrl.barrier_in.wait()
         self.run_obs_norm(action_space)
         self.ctrl.barrier_out.wait()
         drain_queue(self.traj_infos_queue)
 def evaluate_agent(self, itr):
     self.ctrl.do_eval.value = True
     self.sync.stop_eval.value = False
     self.agent.eval_mode(itr)
     self.ctrl.barrier_in.wait()
     traj_infos = self.serve_actions_evaluation(itr)
     self.ctrl.barrier_out.wait()
     traj_infos.extend(
         drain_queue(self.eval_traj_infos_queue, n_sentinel=self.n_worker)
     )  # Block until all finish submitting.
     self.ctrl.do_eval.value = False
     return traj_infos
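
init_obs_norm above warms up the bonus model's observation normalization with randomly sampled actions instead of agent actions. A worker-side sketch of what one warm-up pass could look like, assuming a gym-style env API and a bonus model exposing an update_obs_stats method (both are assumptions for illustration, not names from this codebase):

def run_obs_norm_pass(envs, action_space, bonus_model, steps_per_env):
    # Hypothetical: step each env with random actions, feeding only the
    # observations to the bonus model's running normalization statistics.
    for env in envs:
        obs = env.reset()
        for _ in range(steps_per_env):
            action = action_space.sample()       # The agent network is never queried.
            obs, reward, done, info = env.step(action)
            bonus_model.update_obs_stats(obs)    # Update running mean / std of observations.
            if done:
                obs = env.reset()
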
Example #12
 def obtain_samples(self, itr):
     """Signals worker to begin environment step execution loop, and drops
     into ``serve_actions()`` method to provide actions to workers based on
     the new observations at each step.
     """
     # self.samples_np[:] = 0  # Reset all batch sample values (optional).
     self.agent.sample_mode(itr)
     self.ctrl.barrier_in.wait()
     self.serve_actions(itr)  # Worker step environments here.
     self.ctrl.barrier_out.wait()
     traj_infos = drain_queue(self.traj_infos_queue)
     return self.samples_pyt, traj_infos
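
serve_actions (and the evaluation variant in later examples) relies on a per-worker semaphore handshake: each worker releases obs_ready after writing its observation into the shared step buffer, then blocks on act_ready until the server has written an action back. A schematic single server step under that assumption (not rlpyt's exact code):

def serve_one_step(agent, step_buffer, obs_ready, act_ready):
    # Schematic action-server step; obs_ready and act_ready are lists of
    # multiprocessing semaphores, one pair per worker.
    for sem in obs_ready:
        sem.acquire()                    # Wait until every worker has posted its observation.
    action, agent_info = agent.step(
        step_buffer.observation, step_buffer.action, step_buffer.reward)
    step_buffer.action[:] = action       # Write actions back into the shared buffer.
    step_buffer.agent_info[:] = agent_info
    for sem in act_ready:
        sem.release()                    # Let every worker apply its action via env.step().
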
Example #13
    def obtain_samples(self, itr):
        """
        采样一批数据。

        :param itr: 第几次迭代。
        :return: TODO
        """
        self.ctrl.itr.value = itr
        self.ctrl.barrier_in.wait()
        # Workers step environments and sample actions here.
        self.ctrl.barrier_out.wait()
        traj_infos = drain_queue(self.traj_infos_queue)
        return self.samples_pyt, traj_infos
Example #14
File: base.py  Project: keirp/glamor
 def obtain_samples(self, itr, db_idx):
     """Communicates to workers which batch buffer to use, and signals them
     to start collection.  Waits until workers finish, and then retrieves
     completed trajectory-info objects from the workers and returns them in
     a list.
     """
     self.ctrl.itr.value = itr
     self.sync.db_idx.value = db_idx  # Double buffer index.
     self.ctrl.barrier_in.wait()
     # Workers step environments and sample actions here.
     self.ctrl.barrier_out.wait()
     traj_infos = drain_queue(self.traj_infos_queue)
     return traj_infos
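
The db_idx value tells workers which of two pre-allocated batch buffers to fill, so that one batch can be consumed while the next is being written. A minimal, purely illustrative sketch of the index alternation (consume_batch is a hypothetical callback; in the real asynchronous runner the two sides live in separate processes):

def run_double_buffered(sampler, consume_batch, n_itr):
    # Illustrative only: alternate the double-buffer index each iteration so
    # workers never overwrite the batch currently being consumed.
    db_idx = 0
    for itr in range(n_itr):
        traj_infos = sampler.obtain_samples(itr, db_idx)  # Workers fill buffer db_idx.
        consume_batch(db_idx, traj_infos)                 # e.g. hand this buffer to the optimizer.
        db_idx ^= 1                                       # Flip to the other buffer.
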
Example #15
    def serve_actions_evaluation(self, itr):
        """Similar to ``serve_actions()``.  If a maximum number of eval trajectories
        was specified, keeps track of the number completed and terminates evaluation
        if the max is reached.  Returns a list of completed trajectory-info objects.
        """
        obs_ready, act_ready = self.sync.obs_ready, self.sync.act_ready
        step_np, step_pyt = self.eval_step_buffer_np, self.eval_step_buffer_pyt
        traj_infos = list()
        self.agent.reset()
        agent_inputs = AgentInputs(
            step_pyt.observation, step_pyt.action, step_pyt.reward
        )  # Fixed buffer objects.

        for t in range(self.eval_max_T):
            if t % EVAL_TRAJ_CHECK == 0:  # (While workers stepping.)
                traj_infos.extend(
                    drain_queue(self.eval_traj_infos_queue, guard_sentinel=True)
                )
            for b in obs_ready:
                b.acquire()
                # assert not b.acquire(block=False)  # Debug check.
            for b_reset in np.where(step_np.done)[0]:
                step_np.action[b_reset] = 0  # Null prev_action.
                step_np.reward[b_reset] = 0  # Null prev_reward.
                self.agent.reset_one(idx=b_reset)
            action, agent_info = self.agent.step(*agent_inputs)
            step_np.action[:] = action
            step_np.agent_info[:] = agent_info
            if self.eval_max_trajectories is not None and t % EVAL_TRAJ_CHECK == 0:
                self.sync.stop_eval.value = (
                    len(traj_infos) >= self.eval_max_trajectories
                )
            for w in act_ready:
                # assert not w.acquire(block=False)  # Debug check.
                w.release()
            if self.sync.stop_eval.value:
                logger.log(
                    "Evaluation reach max num trajectories "
                    f"({self.eval_max_trajectories})."
                )
                break
        if t == self.eval_max_T - 1 and self.eval_max_trajectories is not None:
            logger.log("Evaluation reached max num time steps " f"({self.eval_max_T}).")
        for b in obs_ready:
            b.acquire()  # Workers always do extra release; drain it.
            assert not b.acquire(block=False)  # Debug check.
        for w in act_ready:
            assert not w.acquire(block=False)  # Debug check.

        return traj_infos
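
The stop_eval flag written by the server above, and the trailing None sentinels that the later n_sentinel drain waits for, come from the worker side. A hypothetical sketch of one worker's evaluation loop (eval_collector.step_once is an assumed helper, not part of this codebase):

def eval_worker_loop(sync, eval_collector, eval_traj_infos_queue, eval_max_T):
    # Hypothetical eval loop for one worker: run until eval_max_T steps or until
    # the action server sets stop_eval, pushing trajectories as they complete.
    for t in range(eval_max_T):
        completed = eval_collector.step_once()        # Assumed: one env step per call.
        for traj_info in completed:
            eval_traj_infos_queue.put(traj_info)
        if sync.stop_eval.value:
            break                                     # Server reached eval_max_trajectories.
    eval_traj_infos_queue.put(None)                   # Sentinel: this worker is done submitting.
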
Example #16
 def evaluate_agent(self, itr):
     """Signals workers to begin agent evaluation loop, and drops into
     ``serve_actions_evaluation()`` to provide actions to workers at each
     step.
     """
     self.ctrl.do_eval.value = True
     self.sync.stop_eval.value = False
     self.agent.eval_mode(itr)
     self.ctrl.barrier_in.wait()
     traj_infos = self.serve_actions_evaluation(itr)
     self.ctrl.barrier_out.wait()
     traj_infos.extend(drain_queue(self.eval_traj_infos_queue,
         n_sentinel=self.n_worker))  # Block until all finish submitting.
     self.ctrl.do_eval.value = False
     return traj_infos
Example #17
 def store_diagnostics(self, itr, traj_infos, opt_info):
     traj_infos.extend(drain_queue(self.par.traj_infos_queue))
     super().store_diagnostics(itr, traj_infos, opt_info)
Example #18
    def serve_actions_evaluation(self, itr):
        obs_ready, act_ready = self.sync.obs_ready, self.sync.act_ready
        obs_ready_pair = self.obs_ready_pair
        act_ready_pair = self.act_ready_pair
        step_np, step_np_pair = self.eval_step_buffer_np, self.eval_step_buffer_np_pair
        agent_inputs = self.eval_agent_inputs
        agent_inputs_pair = self.eval_agent_inputs_pair
        traj_infos = list()
        self.agent.reset()
        step_np.action[:] = 0  # Null prev_action.
        step_np.reward[:] = 0  # Null prev_reward.

        # First step of both.
        alt = 0
        step_h = step_np_pair[alt]
        for b in obs_ready_pair[alt]:
            b.acquire()
            # assert not b.acquire(block=False)  # Debug check.
        action, agent_info = self.agent.step(*agent_inputs_pair[alt])
        step_h.action[:] = action
        step_h.agent_info[:] = agent_info
        alt = 1
        step_h = step_np_pair[alt]
        for b in obs_ready_pair[alt]:
            b.acquire()
            # assert not b.acquire(block=False)  # Debug check.
        for w in act_ready_pair[1 - alt]:
            # assert not w.acquire(block=False)  # Debug check.
            w.release()
        action, agent_info = self.agent.step(*agent_inputs_pair[alt])
        step_h.action[:] = action
        step_h.agent_info[:] = agent_info

        for t in range(1, self.eval_max_T):
            if t % EVAL_TRAJ_CHECK == 0:  # (While workers stepping.)
                traj_infos.extend(
                    drain_queue(self.eval_traj_infos_queue,
                                guard_sentinel=True))
            for alt in range(2):
                step_h = step_np_pair[alt]
                for b in obs_ready_pair[alt]:
                    b.acquire()
                    # assert not b.acquire(block=False)  # Debug check.
                for w in act_ready_pair[1 - alt]:
                    # assert not w.acquire(block=False)  # Debug check.
                    w.release()
                for b_reset in np.where(step_h.done)[0]:
                    step_h.action[b_reset] = 0  # Null prev_action.
                    step_h.reward[b_reset] = 0  # Null prev_reward.
                    self.agent.reset_one(idx=b_reset)
                action, agent_info = self.agent.step(*agent_inputs_pair[alt])
                step_h.action[:] = action
                step_h.agent_info[:] = agent_info
            if self.eval_max_trajectories is not None and t % EVAL_TRAJ_CHECK == 0:
                self.sync.stop_eval.value = len(
                    traj_infos) >= self.eval_max_trajectories
            if self.sync.stop_eval.value:
                for w in act_ready_pair[1 - alt]:  # Other released past loop.
                    # assert not w.acquire(block=False)  # Debug check.
                    w.release()
                logger.log("Evaluation reached max num trajectories "
                           f"({self.eval_max_trajectories}).")
                break

        # TODO: check logic when traj limit hits at natural end of loop?
        for w in act_ready_pair[alt]:
            # assert not w.acquire(block=False)  # Debug check.
            w.release()
        if t == self.eval_max_T - 1 and self.eval_max_trajectories is not None:
            logger.log("Evaluation reached max num time steps "
                       f"({self.eval_max_T}).")

        for b in obs_ready:
            b.acquire()  # Workers always do extra release; drain it.
            # assert not b.acquire(block=False)  # Debug check.

        return traj_infos