def populate_task(env, policy, scope=None):
    logger.log("Populating workers...")
    if singleton_pool.n_parallel > 1:
        singleton_pool.run_each(
            _worker_populate_task,
            [(pickle.dumps(env), pickle.dumps(policy), scope)]
            * singleton_pool.n_parallel)
    else:
        # avoid unnecessary copying
        G = _get_scoped_G(singleton_pool.G, scope)
        G.env = env
        G.policy = policy
    logger.log("Populated")


def populate_task(env, policy, ma_mode, scope=None):
    logger.log("Populating workers...")
    logger.log("ma_mode={}".format(ma_mode))
    if singleton_pool.n_parallel > 1:
        singleton_pool.run_each(
            _worker_populate_task,
            [(pickle.dumps(env), pickle.dumps(policy), ma_mode, scope)]
            * singleton_pool.n_parallel)
    else:
        # avoid unnecessary copying
        G = _get_scoped_G(singleton_pool.G, scope)
        G.env = env
        if ma_mode == 'concurrent':
            # in concurrent mode, `policy` holds one policy per agent
            G.policies = policy
        else:
            G.policy = policy
    logger.log("Populated")


def sample_paths(policy_params, max_samples, ma_mode, max_path_length=np.inf,
                 env_params=None, scope=None):
    """Multi-agent variant: `ma_mode` is forwarded to the workers so they know how
    to set policy parameters and collect paths. In 'concurrent' mode,
    `policy_params` must be a list with one entry per agent's policy."""
    if ma_mode == 'concurrent':
        assert isinstance(policy_params, list)
    singleton_pool.run_each(
        _worker_set_policy_params,
        [(policy_params, ma_mode, scope)] * singleton_pool.n_parallel)
    if env_params is not None:
        singleton_pool.run_each(
            _worker_set_env_params,
            [(env_params, scope)] * singleton_pool.n_parallel)
    return singleton_pool.run_collect(
        _worker_collect_path_one_env,
        threshold=max_samples,
        args=(max_path_length, ma_mode, scope),
        show_prog_bar=True)


def sample_paths(
        policy_params,
        max_samples,
        max_path_length=np.inf,
        env_params=None,
        scope=None):
    """
    :param policy_params: parameters for the policy. This will be updated on each worker process
    :param max_samples: desired maximum number of samples to be collected. The actual number of
     collected samples might be greater since all trajectories will be rolled out either until
     termination or until max_path_length is reached
    :param max_path_length: horizon / maximum length of a single trajectory
    :return: a list of collected paths
    """
    singleton_pool.run_each(
        _worker_set_policy_params,
        [(policy_params, scope)] * singleton_pool.n_parallel)
    if env_params is not None:
        singleton_pool.run_each(
            _worker_set_env_params,
            [(env_params, scope)] * singleton_pool.n_parallel)
    return singleton_pool.run_collect(
        _worker_collect_one_path,
        threshold=max_samples,
        args=(max_path_length, scope),
        show_prog_bar=True)


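# Illustrative call (a sketch, not part of the original module): `policy` below is a
# hypothetical, already-constructed rllab policy; its current parameter vector is
# usually obtained via `policy.get_param_values()`.
#
#   paths = sample_paths(
#       policy_params=policy.get_param_values(),
#       max_samples=10000,
#       max_path_length=500)

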
def set_seed(seed):
    # offset the base seed by the worker index so each worker gets a distinct seed
    singleton_pool.run_each(
        _worker_set_seed,
        [(seed + i,) for i in range(singleton_pool.n_parallel)])


def terminate_task(scope=None):
    singleton_pool.run_each(
        _worker_terminate_task,
        [(scope,)] * singleton_pool.n_parallel)


def initialize(n_parallel):
    # spawn the worker pool, then tell each worker its id
    singleton_pool.initialize(n_parallel)
    singleton_pool.run_each(
        _worker_init,
        [(id,) for id in range(singleton_pool.n_parallel)])


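# Typical lifecycle for this parallel sampler (a hedged sketch, not part of the
# original module). `env` and `policy` are assumed to be already-constructed rllab
# objects; `ma_mode` is whatever mode string the caller's setup uses (the only value
# this file branches on explicitly is 'concurrent', in which case `policy` and
# `policy_params` are lists with one entry per agent).
#
#   initialize(n_parallel=4)                # spawn the worker pool
#   set_seed(42)                            # workers get seeds 42, 43, 44, 45
#   populate_task(env, policy, ma_mode)     # ship env/policy to every worker
#   # ... per training iteration: sample_paths(policy_params, max_samples, ...)
#   terminate_task()                        # shut the workers down when done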