def populate_task(env, policy, scope=None):
    """Hand ``env`` and ``policy`` to the sampler workers under ``scope``.

    With more than one parallel worker, ``env`` and ``policy`` are pickled a
    single time and that same payload is sent to each worker through
    ``_worker_populate_task``. With one worker, the objects are attached
    directly to the scoped global state instead, skipping serialization.
    """
    logger.log("Populating workers...")
    n_workers = singleton_pool.n_parallel
    if n_workers <= 1:
        # In-process: attach directly and avoid an unnecessary copy.
        scoped_globals = _get_scoped_G(singleton_pool.G, scope)
        scoped_globals.env = env
        scoped_globals.policy = policy
    else:
        payload = (pickle.dumps(env), pickle.dumps(policy), scope)
        singleton_pool.run_each(_worker_populate_task, [payload] * n_workers)
    logger.log("Populated")
    def step(self, action_n):
        """Advance all environments by one step.

        ``action_n`` is sent (unchanged) to ``worker_run_step`` once per
        allocated worker; replies are merged, put back into env-id order, and
        any environment that reported done or whose step counter reached
        ``max_path_length`` is reset via ``_run_reset``.

        :param action_n: actions for all environments, forwarded as-is
        :return: tuple of (observations as a list, rewards array, dones array,
            env infos stacked with ``tensor_utils.stack_tensor_dict_list``)
        """
        worker_args = [(action_n, self.scope) for _ in self._alloc_env_ids]
        replies = singleton_pool.run_each(worker_run_step, worker_args)
        # Drop empty (None) worker replies before merging.
        replies = [reply for reply in replies if reply is not None]
        id_parts, obs_parts, reward_parts, done_parts, info_parts = zip(*replies)

        env_ids = np.concatenate(id_parts)
        observations = self.observation_space.unflatten_n(np.concatenate(obs_parts))
        rewards = np.concatenate(reward_parts)
        dones = np.concatenate(done_parts)
        infos = tensor_utils.split_tensor_dict_list(
            tensor_utils.concat_tensor_dict_list(info_parts)
        )
        if infos is None:
            infos = [dict() for _ in range(self.num_envs)]

        # Replies arrive grouped by worker; sort rows by env id so that
        # position i lines up with environment i.
        ordered_rows = sorted(
            zip(env_ids, observations, rewards, dones, infos),
            key=lambda row: row[0],
        )
        _, observations, rewards, dones, infos = zip(*ordered_rows)

        observations = list(observations)
        rewards = np.asarray(rewards)
        dones = np.asarray(dones)

        # Force termination once a path reaches the horizon.
        self.ts += 1
        dones[self.ts >= self.max_path_length] = True

        fresh_obs = self._run_reset(dones)
        for env_idx in np.flatnonzero(dones):
            observations[env_idx] = fresh_obs[env_idx]
            self.ts[env_idx] = 0
        stacked_infos = tensor_utils.stack_tensor_dict_list(list(infos))
        return observations, rewards, dones, stacked_infos
    def __init__(self, env, n, max_path_length, scope=None):
        """Spread ``n`` environments built from ``env`` over the worker pool.

        Env ids ``0 .. n-1`` are split into contiguous chunks of at most
        ``ceil(n / n_parallel)`` ids, one chunk per worker (trailing chunks may
        be empty), and each worker is set up through ``worker_init_envs``.

        :param env: template environment; its spaces are stored as
            ``_action_space`` / ``_observation_space``
        :param n: total number of environments
        :param max_path_length: step count at which ``step`` marks an env done
        :param scope: key for these envs on the workers; a random UUID string
            is generated when None
        """
        if scope is None:
            # No scope given: pick a unique random one for this instance.
            scope = str(uuid.uuid4())

        n_workers = singleton_pool.n_parallel
        chunk_size = int(np.ceil(n * 1.0 / n_workers))
        alloc_env_ids = []
        next_id, remaining = 0, n
        for _ in range(n_workers):
            take = min(chunk_size, remaining)
            alloc_env_ids.append(list(range(next_id, next_id + take)))
            next_id += take
            remaining = max(0, remaining - chunk_size)

        init_args = [(chunk, scope, env) for chunk in alloc_env_ids]
        singleton_pool.run_each(worker_init_envs, init_args)

        self._alloc_env_ids = alloc_env_ids
        self._action_space = env.action_space
        self._observation_space = env.observation_space
        self._num_envs = n
        self.scope = scope
        # Per-env step counters, used to enforce max_path_length.
        self.ts = np.zeros(n, dtype='int')
        self.max_path_length = max_path_length
def sample_paths(
        policy_params,
        max_samples,
        max_path_length=np.inf,
        env_params=None,
        scope=None,
        reset_arg=None,
        show_prog_bar=True,
        multi_task=False):
    """
    Push policy (and optionally env) parameters to the workers, then collect paths.

    :param policy_params: parameters for the policy. This will be updated on each worker process.
        When ``multi_task`` is True this must be a sequence with exactly one entry per worker.
    :param max_samples: desired maximum number of samples to be collected. The actual number of collected samples
        might be greater since all trajectories will be rolled out either until termination or until max_path_length is
        reached
    :param max_path_length: horizon / maximum length of a single trajectory
    :param env_params: if not None, sent to every worker via ``_worker_set_env_params`` before sampling
    :param scope: scope key forwarded to every worker-side call
    :param reset_arg: forwarded to ``_worker_collect_one_path``; when ``multi_task`` is True it must be
        iterable, and one collection-argument tuple is built per element
    :param show_prog_bar: forwarded to ``singleton_pool.run_collect``
    :param multi_task: if True, each worker gets its own entry of ``policy_params`` and
        ``run_collect`` receives a list of per-task argument tuples
    :raises ValueError: if ``multi_task`` is True and ``len(policy_params)`` does not equal
        the number of parallel workers
    :return: a list of collected paths
    """
    n_workers = singleton_pool.n_parallel
    if multi_task:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(policy_params) != n_workers:
            raise ValueError(
                "multi_task sampling needs one policy_params entry per worker: "
                "got %d, expected %d" % (len(policy_params), n_workers)
            )
        param_args = [(params, scope) for params in policy_params]
    else:
        param_args = [(policy_params, scope)] * n_workers
    singleton_pool.run_each(_worker_set_policy_params, param_args)

    if env_params is not None:
        singleton_pool.run_each(
            _worker_set_env_params,
            [(env_params, scope)] * n_workers,
        )

    if multi_task:
        collect_args = [(max_path_length, scope, arg) for arg in reset_arg]
    else:
        collect_args = (max_path_length, scope, reset_arg)
    return singleton_pool.run_collect(
        _worker_collect_one_path,
        threshold=max_samples,
        args=collect_args,
        show_prog_bar=show_prog_bar,
        multi_task=multi_task,
    )
    def _run_reset(self, dones):
        """Reset the environments flagged in ``dones`` via ``worker_run_reset``.

        :param dones: per-environment flags, indexed by env id
        :return: list of length ``self.num_envs``; entry ``i`` is the
            unflattened reset observation when ``dones[i]`` is truthy,
            otherwise None
        """
        dones = np.asarray(dones)
        worker_args = [(dones, self.scope) for _ in self._alloc_env_ids]
        replies = singleton_pool.run_each(worker_run_reset, worker_args)
        env_ids, flat_obs = [np.concatenate(part) for part in zip(*replies)]
        # Rows come back grouped by worker; reorder them by env id.
        rows_by_id = sorted(zip(env_ids, flat_obs), key=lambda pair: pair[0])
        sorted_obs = np.asarray([row for _, row in rows_by_id])

        (reset_ids,) = np.where(dones)
        reset_obs = self.observation_space.unflatten_n(sorted_obs[reset_ids])
        all_obs = [None] * self.num_envs
        for position, env_idx in enumerate(reset_ids):
            all_obs[env_idx] = reset_obs[position]
        return all_obs
Example #6
0
 def start_worker(self):
     """Populate the sampler workers with this algorithm's env and policy.

     When ``singleton_pool.n_parallel > 1``, ``worker_init_tf`` is run via
     ``singleton_pool.run_each`` before populating and ``worker_init_tf_vars``
     is run afterwards.
     """
     use_workers = singleton_pool.n_parallel > 1
     if use_workers:
         singleton_pool.run_each(worker_init_tf)
     parallel_sampler.populate_task(self.algo.env, self.algo.policy)
     if use_workers:
         singleton_pool.run_each(worker_init_tf_vars)
def set_seed(seed):
    """Seed the workers: the i-th ``_worker_set_seed`` call receives ``seed + i``."""
    per_worker_seeds = [(seed + offset,) for offset in range(singleton_pool.n_parallel)]
    singleton_pool.run_each(_worker_set_seed, per_worker_seeds)
def terminate_task(scope=None):
    """Run ``_worker_terminate_task`` with ``scope`` once per parallel worker."""
    per_worker_args = [(scope,) for _ in range(singleton_pool.n_parallel)]
    singleton_pool.run_each(_worker_terminate_task, per_worker_args)
def initialize(n_parallel):
    """Initialise the worker pool with ``n_parallel`` workers and run per-worker setup.

    ``_worker_init`` is called once per worker, the i-th call receiving ``i``.
    The worker count is read back from ``singleton_pool.n_parallel`` after
    initialisation, matching the pool's own view of how many workers exist.
    """
    singleton_pool.initialize(n_parallel)
    # `worker_id` rather than `id`, to avoid shadowing the builtin.
    singleton_pool.run_each(
        _worker_init,
        [(worker_id,) for worker_id in range(singleton_pool.n_parallel)],
    )