Ejemplo n.º 1
0
def assemble_worker_kwargs(n, par_objs, seed, affinities, common_kwargs):
    """Build the keyword-argument dict for each of the ``2 * n`` workers.

    Workers are enumerated group-major (all ranks of group 0, then all ranks
    of group 1). The running index is used as the worker's unique ID, its
    seed offset, and its position in both the CPU list and the list of
    per-worker segment buffers.

    Args:
        n: Number of workers per group; there are always two groups.
        par_objs: Either the 5-tuple ``(ctrl, sync, envs_buf, step_bufs,
            traj_infos_queue)`` or that tuple followed by ``eval_step_bufs``.
        seed: Base seed; the worker with index ``i`` receives ``seed + i``.
        affinities: Mapping whose ``"sim_cpus"`` entry, when present, is
            indexed by worker index; otherwise ``0 .. 2n-1`` is used.
        common_kwargs: Extra keyword arguments merged into every dict.

    Returns:
        A list of ``2 * n`` dicts with keys ``group``, ``rank``,
        ``unique_ID``, ``seed``, ``cpu``, ``par_objs`` plus everything in
        ``common_kwargs``. Each ``par_objs`` value is the tuple
        ``(ctrl, traj_infos_queue, sync, segs_buf, step_buf)``, extended by
        the group's eval step buffer when ``eval_step_bufs`` was supplied.
    """
    if len(par_objs) != 5:
        (ctrl, sync, envs_buf, step_bufs, traj_infos_queue,
         eval_step_bufs) = par_objs
    else:
        ctrl, sync, envs_buf, step_bufs, traj_infos_queue = par_objs
        eval_step_bufs = None

    segs_per_worker = view_worker_segs_bufs(envs_buf.segs_view, 2 * n)
    cpus = affinities.get("sim_cpus", list(range(2 * n)))
    shared_objs = (ctrl, traj_infos_queue)
    has_eval = eval_step_bufs is not None

    all_kwargs = []
    # Lazily yields (group, rank) in the same order as two nested loops.
    slots = ((g, r) for g in range(2) for r in range(n))
    for uid, (group, rank) in enumerate(slots):
        blockers = struct(
            step_blocker=sync.step_blockers[group][rank],
            act_waiter=sync.act_waiters[group][rank],
        )
        # The step buffer is shared by the whole group; the segment buffer
        # is specific to this worker.
        own_objs = (blockers, segs_per_worker[uid], step_bufs[group])
        if has_eval:
            own_objs = own_objs + (eval_step_bufs[group],)
        all_kwargs.append(dict(
            group=group,
            rank=rank,
            unique_ID=uid,
            seed=seed + uid,
            cpu=cpus[uid],
            par_objs=shared_objs + own_objs,
            **common_kwargs))
    return all_kwargs
Ejemplo n.º 2
0
def batch_buffer(example, length, shared=False):
    """Recursively build a buffer that mirrors the nesting of ``example``.

    Args:
        example: Either a dict (possibly nested) or a leaf value that is
            handed to ``build_array``.
        length: Forwarded unchanged to ``build_array`` for every leaf.
        shared: Forwarded unchanged to ``build_array`` for every leaf.

    Returns:
        For a dict ``example``, a ``struct`` with the same keys whose values
        are the buffers built for the corresponding entries; otherwise the
        result of ``build_array(example, length, shared)``.
    """
    # Leaf case: anything that is not a dict becomes a single array.
    if not isinstance(example, dict):
        return build_array(example, length, shared)

    nested = struct()
    for key, sub_example in example.items():
        nested[key] = batch_buffer(sub_example, length, shared)
    return nested
Ejemplo n.º 3
0
def _recurse_segment(batch_buffer, start, segment_length):
    """Slice every leaf of a nested buffer to one segment.

    Args:
        batch_buffer: Mapping whose values are either dicts (recursed into)
            or sliceable objects.
        start: Index of the first element of the segment.
        segment_length: Number of elements in the segment.

    Returns:
        A ``struct`` with the same keys as ``batch_buffer``; each leaf is
        ``leaf[start:start + segment_length]`` and each nested dict is
        replaced by its own recursively sliced ``struct``.
    """
    window = slice(start, start + segment_length)
    out = struct()
    for name, field in batch_buffer.items():
        out[name] = (
            _recurse_segment(field, start, segment_length)
            if isinstance(field, dict)
            else field[window]
        )
    return out
Ejemplo n.º 4
0
def view_serve_groups(policy_buf, extra_observations=None):
    """Split a policy buffer (and optional extra observations) into halves.

    Args:
        policy_buf: Object exposing ``actions`` (sliceable) and
            ``agent_infos`` (mapping of name to sliceable); its length is
            obtained from ``buffer_length``.
        extra_observations: Optional sliceable supporting ``len``; split at
            its own midpoint when given.

    Returns:
        ``(group_0, group_1)`` when ``extra_observations`` is ``None``, where
        each group is a ``struct`` with ``actions`` and ``agent_infos`` (a
        plain dict) sliced to the first / second half. Otherwise
        ``((group_0, group_1), (extra_obs_0, extra_obs_1))``.
    """
    mid = buffer_length(policy_buf) // 2

    def group_view(part):
        # Apply the same slice to the actions and to every agent-info field.
        return struct(
            actions=policy_buf.actions[part],
            agent_infos={name: field[part]
                         for name, field in policy_buf.agent_infos.items()},
        )

    groups = (group_view(slice(None, mid)), group_view(slice(mid, None)))
    if extra_observations is None:
        return groups

    obs_mid = len(extra_observations) // 2
    obs_groups = (extra_observations[:obs_mid], extra_observations[obs_mid:])
    return groups, obs_groups
Ejemplo n.º 5
0
 def build_par_objs(self, n_runners):
     """Create the multiprocessing objects shared between runners.

     Args:
         n_runners: Number of parties for the barrier.

     Returns:
         A ``struct`` with ``barrier`` (an ``mp.Barrier`` for ``n_runners``
         parties), ``dict`` (a dict proxy obtained from a new
         ``mp.Manager``) and ``traj_infos_queue`` (an ``mp.Queue``).
     """
     # Construction order is kept: barrier, manager, managed dict, queue.
     sync_barrier = mp.Barrier(n_runners)
     manager = mp.Manager()
     return struct(
         barrier=sync_barrier,
         dict=manager.dict(),
         traj_infos_queue=mp.Queue(),
     )
Ejemplo n.º 6
0
def build_par_objs(env, n, envs_per, horizon, extra_obs, eval_envs_per=None):
    """Create the control, synchronization and buffer objects for 2 groups.

    Args:
        env: Environment passed to ``build_env_buffer``; its ``spec``
            attribute is passed to ``build_step_buffer``.
        n: Number of workers per group (two groups in total).
        envs_per: Multiplier used when sizing the env and step buffers.
        horizon: Multiplier for the env buffer size; also forwarded to
            ``build_env_buffer``.
        extra_obs: Forwarded to ``build_env_buffer``.
        eval_envs_per: When not ``None``, a pair of eval step buffers sized
            ``n * eval_envs_per`` is appended to the result.

    Returns:
        The tuple ``(ctrl, sync, envs_buf, step_bufs, traj_infos_queue)``,
        with ``eval_step_bufs`` appended as a sixth item when
        ``eval_envs_per`` is given.
    """
    n_workers = 2 * n

    def semaphore_grid():
        # One zero-initialized semaphore per (group, rank).
        return [[mp.Semaphore(0) for _ in range(n)] for _ in range(2)]

    quit_flag = mp.RawValue(ctypes.c_bool, False)
    eval_flag = mp.RawValue(ctypes.c_bool, False)
    # Barriers take one party more than the number of workers
    # (presumably for the coordinating process; confirm with the caller).
    ctrl = struct(
        quit=quit_flag,
        do_eval=eval_flag,
        barrier_in=mp.Barrier(n_workers + 1),
        barrier_out=mp.Barrier(n_workers + 1),
    )
    sync = struct(
        step_blockers=semaphore_grid(),
        act_waiters=semaphore_grid(),
    )

    envs_buf = build_env_buffer(
        env, n_workers * envs_per * horizon, horizon, extra_obs)
    step_bufs = [build_step_buffer(env.spec, n * envs_per) for _ in range(2)]
    core = (ctrl, sync, envs_buf, step_bufs, mp.Queue())

    if eval_envs_per is None:
        return core

    eval_step_bufs = [
        build_step_buffer(env.spec, n * eval_envs_per) for _ in range(2)
    ]
    return core + (eval_step_bufs, )
Ejemplo n.º 7
0
def view_segments(batch_buffer, segment_length):
    """Split a batch buffer into consecutive segments of equal length.

    Args:
        batch_buffer: Mapping whose values are either nested mappings
            (``struct`` or plain ``dict``) or sliceable leaves; its length
            is obtained from ``buffer_length``.
        segment_length: Positive number of elements per segment. It must
            divide the buffer length exactly.

    Returns:
        A list of ``struct`` objects, one per segment in order. Each has the
        same keys as ``batch_buffer``; leaves are sliced to the segment and
        nested mappings are sliced recursively via ``_recurse_segment``.

    Raises:
        ValueError: If ``segment_length`` is not positive, or if the buffer
            length is not divisible by ``segment_length``.
    """
    # Reject zero (would divide by zero below) and negative values (would
    # silently produce an empty result).
    if segment_length <= 0:
        raise ValueError(
            "segment_length must be positive, got {}".format(segment_length))
    length = buffer_length(batch_buffer)
    if length % segment_length != 0:
        raise ValueError("Buffer length ({}) not divisible by requested "
            "segment_length ({})".format(length, segment_length))
    segments = list()
    for start in range(0, length, segment_length):
        segment = struct()
        for k, v in batch_buffer.items():
            # Recurse on plain dicts as well as structs so the top level
            # agrees with _recurse_segment, which descends into any dict.
            if isinstance(v, (struct, dict)):
                segment[k] = _recurse_segment(v, start, segment_length)
            else:
                segment[k] = v[start:start + segment_length]
        segments.append(segment)
    return segments