Exemple #1
0
def sampling_process(common_kwargs, worker_kwargs):
    """Worker-process entry point; arguments are fed from the Sampler class
    in the master process.

    Creates ``w.n_envs`` training environments and their collector, starts
    the environments and the agent, and, when ``eval_n_envs`` is positive,
    also creates evaluation environments and an evaluation collector.  It
    then loops on the ``ctrl`` barriers: each round either runs an evaluation
    pass or collects one training batch, until ``ctrl.quit`` is set.  Every
    environment is closed before returning.

    Args:
        common_kwargs: Mapping unpacked into ``c``; read below for the
            environment/collector classes and their settings, the ``ctrl``
            synchronisation object and the trajectory-info queues.
        worker_kwargs: Mapping unpacked into ``w``; read below for this
            worker's ``rank``, ``seed``, ``cpus``, ``n_envs``, ``samples_np``
            and the optional ``sync``, step buffers and ``env_ranks``.
    """
    c, w = AttrDict(**common_kwargs), AttrDict(**worker_kwargs)
    initialize_worker(w.rank, w.seed, w.cpus, c.torch_threads)
    envs = [c.EnvCls(**c.env_kwargs) for _ in range(w.n_envs)]
    collector = c.CollectorCls(
        rank=w.rank,
        envs=envs,
        samples_np=w.samples_np,
        batch_T=c.batch_T,
        TrajInfoCls=c.TrajInfoCls,
        agent=c.get("agent", None),  # Optional depending on parallel setup.
        sync=w.get("sync", None),
        step_buffer_np=w.get("step_buffer_np", None),
        global_B=c.get("global_B", 1),
        env_ranks=w.get("env_ranks", None),
    )

    agent_inputs, traj_infos = collector.start_envs(c.max_decorrelation_steps)
    collector.start_agent()

    if c.get("eval_n_envs", 0) > 0:
        # NOTE(review): evaluation envs are built from the training
        # ``env_kwargs`` here -- confirm this is intended.
        eval_envs = [c.EnvCls(**c.env_kwargs) for _ in range(c.eval_n_envs)]
        eval_collector = c.eval_CollectorCls(
            rank=w.rank,
            envs=eval_envs,
            TrajInfoCls=c.TrajInfoCls,
            traj_infos_queue=c.eval_traj_infos_queue,
            max_T=c.eval_max_T,
            agent=c.get("agent", None),
            sync=w.get("sync", None),
            step_buffer_np=w.get("eval_step_buffer_np", None),
        )
    else:
        eval_envs = list()
        eval_collector = None

    ctrl = c.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            break
        # The evaluation branch used to be commented out, which left the
        # ``else`` attached to the quit check: the evaluation collector was
        # built but never run, and an evaluation request produced a training
        # batch instead.  Without an evaluation collector the previous
        # behaviour (collect a training batch) is kept.
        if ctrl.do_eval.value and eval_collector is not None:
            # Traj_infos to queue inside.
            eval_collector.collect_evaluation(ctrl.itr.value)
        else:
            agent_inputs, traj_infos, completed_infos = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            for info in completed_infos:
                c.traj_infos_queue.put(info)
        ctrl.barrier_out.wait()

    for env in envs + eval_envs:
        env.close()
Exemple #2
0
def sampling_process(common_kwargs, worker_kwargs):
    """Entry point run inside each forked sampler worker process.

    Steps, in order:

    1. ``initialize_worker()`` is called with this worker's rank, seed, CPU
       assignment and the torch thread count.
    2. ``n_envs`` training environments are built and seeded with the worker
       seed, handed to a new collector, and the collector's environment and
       agent startup methods are run.
    3. If ``eval_n_envs`` is positive, evaluation environments are built from
       ``eval_env_kwargs``, seeded with the same worker seed, and given to
       an evaluation collector.
    4. The worker loops on the controller barriers, running either an
       evaluation pass or one training batch per round, until the quit flag
       is set.  All environments are then closed.
    """
    common = AttrDict(**common_kwargs)
    worker = AttrDict(**worker_kwargs)
    initialize_worker(
        worker.rank, worker.seed, worker.cpus, common.torch_threads)

    # Optional entries shared by the training and evaluation collectors.
    agent = common.get("agent", None)  # Optional depending on parallel setup.
    sync = worker.get("sync", None)

    # --- Training environments and collector -----------------------------
    envs = [common.EnvCls(**common.env_kwargs) for _ in range(worker.n_envs)]
    set_envs_seeds(envs, worker.seed)
    collector = common.CollectorCls(
        rank=worker.rank,
        envs=envs,
        samples_np=worker.samples_np,
        batch_T=common.batch_T,
        TrajInfoCls=common.TrajInfoCls,
        agent=agent,
        sync=sync,
        step_buffer_np=worker.get("step_buffer_np", None),
        global_B=common.get("global_B", 1),
        env_ranks=worker.get("env_ranks", None),
    )
    agent_inputs, traj_infos = collector.start_envs(
        common.max_decorrelation_steps)
    collector.start_agent()

    # --- Optional evaluation environments and collector ------------------
    eval_envs = []
    if common.get("eval_n_envs", 0) > 0:
        for _ in range(common.eval_n_envs):
            eval_envs.append(common.EnvCls(**common.eval_env_kwargs))
        set_envs_seeds(eval_envs, worker.seed)
        eval_collector = common.eval_CollectorCls(
            rank=worker.rank,
            envs=eval_envs,
            TrajInfoCls=common.TrajInfoCls,
            traj_infos_queue=common.eval_traj_infos_queue,
            max_T=common.eval_max_T,
            agent=agent,
            sync=sync,
            step_buffer_np=worker.get("eval_step_buffer_np", None),
        )

    # --- Main loop, paced by the controller's two barriers ---------------
    ctrl = common.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            break
        if ctrl.do_eval.value:
            # The evaluation collector puts its traj_infos on its own queue.
            eval_collector.collect_evaluation(ctrl.itr.value)
        else:
            batch = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            agent_inputs, traj_infos, finished_infos = batch
            for finished in finished_infos:
                common.traj_infos_queue.put(finished)
        ctrl.barrier_out.wait()

    # --- Shutdown --------------------------------------------------------
    for env in [*envs, *eval_envs]:
        env.close()
Exemple #3
0
def sampling_process(common_kwargs, worker_kwargs):
    """
    Sampling-process function; arguments are fed from the Sampler class in
    the master process.

    :param common_kwargs: arguments common to every worker.
    :param worker_kwargs: arguments that may differ between workers.
    """
    cfg = AttrDict(**common_kwargs)
    wrk = AttrDict(**worker_kwargs)
    initialize_worker(wrk.rank, wrk.seed, wrk.cpus, cfg.torch_threads)

    # Set up the environment instances and the collector used for training.
    envs = [cfg.EnvCls(**cfg.env_kwargs) for _ in range(wrk.n_envs)]
    train_collector_kwargs = dict(
        rank=wrk.rank,
        envs=envs,
        samples_np=wrk.samples_np,
        batch_T=cfg.batch_T,
        TrajInfoCls=cfg.TrajInfoCls,
        agent=cfg.get("agent", None),  # Optional depending on parallel setup.
        sync=wrk.get("sync", None),
        step_buffer_np=wrk.get("step_buffer_np", None),
        global_B=cfg.get("global_B", 1),
        env_ranks=wrk.get("env_ranks", None),
    )
    collector = cfg.CollectorCls(**train_collector_kwargs)
    # Per the original note, this call does the work of collecting
    # (sampling) the first batch of data.
    agent_inputs, traj_infos = collector.start_envs(
        cfg.max_decorrelation_steps)
    collector.start_agent()  # Collector initialisation.

    # Set up the environment instances and the collector used for
    # evaluation, when any evaluation environments are requested.
    if cfg.get("eval_n_envs", 0) <= 0:
        eval_envs = list()
    else:
        eval_envs = [
            cfg.EnvCls(**cfg.eval_env_kwargs) for _ in range(cfg.eval_n_envs)
        ]
        eval_collector = cfg.eval_CollectorCls(
            rank=wrk.rank,
            envs=eval_envs,
            TrajInfoCls=cfg.TrajInfoCls,
            traj_infos_queue=cfg.eval_traj_infos_queue,
            max_T=cfg.eval_max_T,
            agent=cfg.get("agent", None),
            sync=wrk.get("sync", None),
            step_buffer_np=wrk.get("eval_step_buffer_np", None),
        )

    # Controller that keeps the worker processes operating correctly when
    # several of them run at the same time.
    ctrl = cfg.ctrl
    # Per the original note: every worker calls wait() once here, and
    # together with one wait() in ParallelSamplerBase.initialize() that
    # makes exactly n_worker + 1 parties.
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        # Once the master process sets this flag to True, every worker
        # process leaves the sampling loop.
        if ctrl.quit.value:
            break
        if not ctrl.do_eval.value:
            # Not evaluating: collect one training batch.
            collected = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            agent_inputs, traj_infos, completed_infos = collected
            for info in completed_infos:
                # Push this worker's statistics onto the queue shared by
                # all worker processes.
                cfg.traj_infos_queue.put(info)
        else:
            # Evaluation data is collected only after the master process
            # has set this flag to True in its evaluate_agent() function.
            # Traj_infos go to the queue inside the collector.
            eval_collector.collect_evaluation(ctrl.itr.value)
        ctrl.barrier_out.wait()

    # Clean up the environments.
    for env in [*envs, *eval_envs]:
        env.close()
Exemple #4
0
def _mux_sampler(common_kwargs, worker_kwargs):
    """Variant of `rlpyt.samplers.parallel.worker.sampling_process` that can
    give each environment in a batch its own keyword arguments.

    When ``env_kwargs`` is a list or tuple, it is indexed by the worker's
    ``env_ranks`` and one environment is built per rank.  Otherwise
    ``env_kwargs`` is used as-is for each of the worker's ``n_envs``
    environments."""
    common = AttrDict(**common_kwargs)
    worker = AttrDict(**worker_kwargs)
    initialize_worker(
        worker.rank, worker.seed, worker.cpus, common.torch_threads)

    # Environment construction is the only part that differs from the stock
    # worker function named in the docstring.
    train_env_kwargs = common.env_kwargs
    if not isinstance(train_env_kwargs, (list, tuple)):
        envs = [
            common.EnvCls(**train_env_kwargs) for _ in range(worker.n_envs)
        ]
    else:
        envs = [
            common.EnvCls(**train_env_kwargs[env_rank])
            for env_rank in worker["env_ranks"]
        ]

    # Optional entries shared by the training and evaluation collectors.
    agent = common.get("agent", None)  # Optional depending on parallel setup.
    sync = worker.get("sync", None)

    collector = common.CollectorCls(
        rank=worker.rank,
        envs=envs,
        samples_np=worker.samples_np,
        batch_T=common.batch_T,
        TrajInfoCls=common.TrajInfoCls,
        agent=agent,
        sync=sync,
        step_buffer_np=worker.get("step_buffer_np", None),
        global_B=common.get("global_B", 1),
        env_ranks=worker.get("env_ranks", None),
    )
    agent_inputs, traj_infos = collector.start_envs(
        common.max_decorrelation_steps)
    collector.start_agent()

    # Evaluation environments and collector exist only when requested.
    eval_envs = []
    if common.get("eval_n_envs", 0) > 0:
        for _ in range(common.eval_n_envs):
            eval_envs.append(common.EnvCls(**common.eval_env_kwargs))
        eval_collector = common.eval_CollectorCls(
            rank=worker.rank,
            envs=eval_envs,
            TrajInfoCls=common.TrajInfoCls,
            traj_infos_queue=common.eval_traj_infos_queue,
            max_T=common.eval_max_T,
            agent=agent,
            sync=sync,
            step_buffer_np=worker.get("eval_step_buffer_np", None),
        )

    ctrl = common.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            break
        if ctrl.do_eval.value:
            # Traj_infos to queue inside.
            eval_collector.collect_evaluation(ctrl.itr.value)
        else:
            outcome = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            agent_inputs, traj_infos, finished_infos = outcome
            for finished in finished_infos:
                common.traj_infos_queue.put(finished)
        ctrl.barrier_out.wait()

    for env in [*envs, *eval_envs]:
        env.close()
Exemple #5
0
def sampling_process(common_kwargs, worker_kwargs):
    """Target function used for forking parallel worker processes in the
    samplers. After ``initialize_worker()``, it creates the specified number
    of environment instances and gives them to the collector when
    instantiating it.  It then calls collector startup methods for
    environments and agent.  If applicable, instantiates evaluation
    environment instances and evaluation collector.

    Then enters infinite loop, waiting for signals from master to collect
    training samples or else run evaluation, until signaled to exit.

    Additional behaviour of this variant:

    * When ``env_kwargs['log_heatmaps']`` equals ``True``, ``log_heatmaps``
      is switched off on every training environment except the first.
    * When ``record_freq`` is positive, the first training environment is
      set up for recording: for games listed in ``ATARI_ENVS`` its
      ``record_env`` flag is set and ``<log_dir>/videos/frames`` is created;
      otherwise, and only if no evaluation environments are configured, it
      is wrapped in ``Monitor``.  With evaluation environments configured,
      the first evaluation environment is wrapped in ``Monitor`` instead.

    Args:
        common_kwargs: Mapping unpacked into ``c``; holds the settings read
            as ``c.<name>`` below.
        worker_kwargs: Mapping unpacked into ``w``; holds the settings read
            as ``w.<name>`` below.
    """
    c, w = AttrDict(**common_kwargs), AttrDict(**worker_kwargs)
    initialize_worker(w.rank, w.seed, w.cpus, c.torch_threads)

    def record_this_episode(episode_id):
        # Video is recorded on every ``record_freq``-th episode.
        return episode_id % c.record_freq == 0

    envs = [c.EnvCls(**c.env_kwargs) for _ in range(w.n_envs)]

    # Equality (not truthiness) is kept so the set of values that triggers
    # this is unchanged; a missing key yields None, which never equals True.
    if c.env_kwargs.get('log_heatmaps', None) == True:  # noqa: E712
        for env in envs[1:]:
            env.log_heatmaps = False

    if c.record_freq > 0:
        if c.env_kwargs['game'] in ATARI_ENVS:
            envs[0].record_env = True
            # Every worker process runs this function with the same
            # ``log_dir``, and a run may reuse an existing one; without
            # ``exist_ok`` the second caller fails with FileExistsError.
            os.makedirs(os.path.join(c.log_dir, 'videos/frames'),
                        exist_ok=True)
        elif c.get("eval_n_envs", 0) == 0:
            # only record workers if no evaluation processes are performed
            envs[0] = Monitor(envs[0],
                              c.log_dir + '/videos',
                              video_callable=record_this_episode)

    set_envs_seeds(envs, w.seed)

    collector = c.CollectorCls(
        rank=w.rank,
        envs=envs,
        samples_np=w.samples_np,
        batch_T=c.batch_T,
        TrajInfoCls=c.TrajInfoCls,
        agent=c.get("agent", None),  # Optional depending on parallel setup.
        sync=w.get("sync", None),
        step_buffer_np=w.get("step_buffer_np", None),
        global_B=c.get("global_B", 1),
        env_ranks=w.get("env_ranks", None),
        no_extrinsic=c.no_extrinsic)
    agent_inputs, traj_infos = collector.start_envs(c.max_decorrelation_steps)
    collector.start_agent()

    if c.get("eval_n_envs", 0) > 0:
        eval_envs = [
            c.EnvCls(**c.eval_env_kwargs) for _ in range(c.eval_n_envs)
        ]
        if c.record_freq > 0:
            eval_envs[0] = Monitor(eval_envs[0],
                                   c.log_dir + '/videos',
                                   video_callable=record_this_episode)
        set_envs_seeds(eval_envs, w.seed)
        eval_collector = c.eval_CollectorCls(
            rank=w.rank,
            envs=eval_envs,
            TrajInfoCls=c.TrajInfoCls,
            traj_infos_queue=c.eval_traj_infos_queue,
            max_T=c.eval_max_T,
            agent=c.get("agent", None),
            sync=w.get("sync", None),
            step_buffer_np=w.get("eval_step_buffer_np", None),
        )
    else:
        eval_envs = list()

    ctrl = c.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            logger.log('Quitting worker ...')
            break
        if ctrl.do_eval.value:
            eval_collector.collect_evaluation(
                ctrl.itr.value)  # Traj_infos to queue inside.
        else:
            agent_inputs, traj_infos, completed_infos = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            for info in completed_infos:
                c.traj_infos_queue.put(info)
        ctrl.barrier_out.wait()

    for env in envs + eval_envs:
        logger.log('Stopping env ...')
        env.close()