Exemplo n.º 1
0
 def master_runner_initialize(self,
                              agent,
                              bootstrap_value=False,
                              traj_info_kwargs=None):
     """Initialize the agent and allocate double-buffered shared-memory
     sample storage in the master runner process.

     Args:
         agent: the agent to initialize against an example env's spaces.
         bootstrap_value: whether the buffer reserves a bootstrap-value slot.
         traj_info_kwargs: optional dict set as ``_``-prefixed class
             attributes on ``self.TrajInfoCls`` (avoids passing at init).

     Returns:
         tuple: ``(double_buffer, examples)`` — the pair of shared numpy
         sample buffers and the per-field buffer examples.
     """
     # Construct an example of each kind of data that needs to be stored.
     env = self.EnvCls(**self.env_kwargs)
     agent.initialize(
         env.spaces,
         share_memory=True)  # Actual agent initialization, keep.
     # Build two identical shared buffers for double buffering; the
     # original duplicated this call verbatim (and left the first
     # torch-view return unused), so loop instead.
     buffers = []
     examples = None
     for _ in range(2):
         _, samples_np, buffer_examples = build_samples_buffer(
             agent,
             env,
             self.batch_spec,
             bootstrap_value,
             agent_shared=True,
             env_shared=True,
             subprocess=False)  # Would like subprocess=True, but might hang?
         buffers.append(samples_np)
         if examples is None:
             examples = buffer_examples  # Identical each pass; keep first.
     env.close()
     del env  # Only needed for examples; workers build their own envs.
     if traj_info_kwargs:
         for k, v in traj_info_kwargs.items():
             setattr(self.TrajInfoCls, "_" + k, v)
     self.double_buffer = double_buffer = tuple(buffers)
     self.examples = examples
     return double_buffer, examples
Exemplo n.º 2
0
    def initialize(self,
                   agent,
                   affinity=None,
                   seed=None,
                   bootstrap_value=False,
                   traj_info_kwargs=None):
        """Set up the serial sampler: build the batch of environments,
        initialize the agent (no shared memory), allocate the samples
        buffer, construct the collector (and an eval collector when
        ``self.eval_n_envs > 0``), decorrelate the envs, and return the
        buffer examples (e.g. for building a replay buffer).
        """
        n_envs = self.batch_spec.B
        env_list = [
            self.EnvCls(**self.env_kwargs) for _ in range(n_envs)
        ]
        agent.initialize(env_list[0].spaces, share_memory=False)
        samples_pyt, samples_np, examples = build_samples_buffer(
            agent,
            env_list[0],
            self.batch_spec,
            bootstrap_value,
            agent_shared=False,
            env_shared=False,
            subprocess=False)
        if traj_info_kwargs:
            # Stash as class attributes so they need not be passed at init.
            for key, value in traj_info_kwargs.items():
                setattr(self.TrajInfoCls, "_" + key, value)
        collector = self.CollectorCls(
            rank=0,
            envs=env_list,
            samples_np=samples_np,
            batch_T=self.batch_spec.T,
            TrajInfoCls=self.TrajInfoCls,
            agent=agent,
        )
        if self.eval_n_envs > 0:  # Evaluation requested.
            eval_env_list = [
                self.EnvCls(**self.eval_env_kwargs)
                for _ in range(self.eval_n_envs)
            ]
            eval_collector_cls = self.eval_CollectorCls or SerialEvalCollector
            self.eval_collector = eval_collector_cls(
                envs=eval_env_list,
                agent=agent,
                TrajInfoCls=self.TrajInfoCls,
                max_T=self.eval_max_steps // self.eval_n_envs,
                max_trajectories=self.eval_max_trajectories,
            )

        agent_inputs, traj_infos = collector.start_envs(
            self.max_decorrelation_steps)
        collector.start_agent()

        self.agent = agent
        self.samples_pyt = samples_pyt
        self.samples_np = samples_np
        self.collector = collector
        self.agent_inputs = agent_inputs
        self.traj_infos = traj_infos
        logger.log("Serial Sampler initialized.")
        return examples
Exemplo n.º 3
0
    def initialize(self,
                   agent,
                   affinity,
                   seed,
                   bootstrap_value=False,
                   traj_info_kwargs=None):
        """Set up a parallel (multiprocess) sampler: distribute envs over
        worker CPUs, initialize the agent with shared memory, allocate a
        shared samples buffer, spawn worker processes, and wait for them
        to finish decorrelating their envs.

        Args:
            agent: agent to initialize; shared with workers via shared memory.
            affinity: dict with at least "workers_cpus" (list per worker) and
                optionally "worker_torch_threads".
            seed: base seed handed to workers via assemble_workers_kwargs.
            bootstrap_value: whether the buffer reserves a bootstrap-value slot.
            traj_info_kwargs: optional dict set as ``_``-prefixed class
                attributes on ``self.TrajInfoCls``.

        Returns:
            Buffer examples, e.g. in case useful to build a replay buffer.
        """
        n_parallel = len(affinity["workers_cpus"])
        # Even split of batch_B envs across workers; remainder handled below.
        n_envs_list = [self.batch_spec.B // n_parallel] * n_parallel
        if not self.batch_spec.B % n_parallel == 0:
            logger.log(
                "WARNING: unequal number of envs per process, from "
                f"batch_B {self.batch_spec.B} and n_parallel {n_parallel} "
                "(possibly suboptimal speed).")
            for b in range(self.batch_spec.B % n_parallel):
                # Give one extra env to each of the first remainder workers.
                n_envs_list[b] += 1

        # Construct an example of each kind of data that needs to be stored.
        env = self.EnvCls(**self.env_kwargs)
        agent.initialize(env.spaces,
                         share_memory=True)  # Actual agent initialization.
        samples_pyt, samples_np, examples = build_samples_buffer(
            agent,
            env,
            self.batch_spec,
            bootstrap_value,
            agent_shared=True,
            env_shared=True,
            subprocess=True)  # TODO: subprocess=True fix!!
        env.close()  # Example env no longer needed; workers make their own.
        del env

        # Parallel-coordination objects; presumably ctrl holds barriers/flags,
        # sync the worker-shared state -- confirm against build_par_objs.
        ctrl, traj_infos_queue, sync = build_par_objs(n_parallel)
        if traj_info_kwargs:
            for k, v in traj_info_kwargs.items():
                setattr(self.TrajInfoCls, "_" + k, v)  # Avoid passing at init.

        if self.eval_n_envs > 0:
            # assert self.eval_n_envs % n_parallel == 0
            # At least one eval env per worker; total may exceed eval_n_envs.
            eval_n_envs_per = max(1, self.eval_n_envs // n_parallel)
            eval_n_envs = eval_n_envs_per * n_parallel
            logger.log(f"Total parallel evaluation envs: {eval_n_envs}")
            self.eval_max_T = eval_max_T = self.eval_max_steps // eval_n_envs
        else:
            eval_n_envs_per = 0
            eval_max_T = None

        # Kwargs identical for every worker; per-worker kwargs come from
        # assemble_workers_kwargs below.
        common_kwargs = dict(
            EnvCls=self.EnvCls,
            env_kwargs=self.env_kwargs,
            agent=agent,
            batch_T=self.batch_spec.T,
            CollectorCls=self.CollectorCls,
            TrajInfoCls=self.TrajInfoCls,
            traj_infos_queue=traj_infos_queue,
            ctrl=ctrl,
            max_decorrelation_steps=self.max_decorrelation_steps,
            torch_threads=affinity.get("worker_torch_threads", None),
            eval_n_envs=eval_n_envs_per,
            eval_CollectorCls=self.eval_CollectorCls or EvalCollector,
            eval_env_kwargs=self.eval_env_kwargs,
            eval_max_T=eval_max_T,
        )

        workers_kwargs = assemble_workers_kwargs(affinity, seed, samples_np,
                                                 n_envs_list, sync)

        workers = [
            mp.Process(target=sampling_process,
                       kwargs=dict(common_kwargs=common_kwargs,
                                   worker_kwargs=w_kwargs))
            for w_kwargs in workers_kwargs
        ]
        for w in workers:
            w.start()

        self.workers = workers
        self.ctrl = ctrl
        self.traj_infos_queue = traj_infos_queue
        self.sync = sync
        self.samples_pyt = samples_pyt
        self.samples_np = samples_np
        self.agent = agent

        self.ctrl.barrier_out.wait()  # Wait for workers to decorrelate envs.
        return examples  # e.g. In case useful to build replay buffer.