示例#1
0
def populate_task(env, policy, scope=None):
    """
    Install ``env`` and ``policy`` on the sampler worker(s).

    With more than one parallel worker, both objects are pickled once and the
    same payload is dispatched to every worker through ``run_each``. Otherwise
    they are assigned directly onto the in-process scoped globals.

    :param env: environment object to install
    :param policy: policy object to install
    :param scope: forwarded to ``_get_scoped_G`` (in-process case) or to the
        worker task (multi-process case)
    """
    logger.log("Populating workers...")
    n_workers = singleton_pool.n_parallel
    if n_workers <= 1:
        # Single process: hand the objects over as-is instead of paying for a
        # pickle round-trip.
        shared = _get_scoped_G(singleton_pool.G, scope)
        shared.env = env
        shared.policy = policy
    else:
        # Serialize once; every worker receives an identical payload.
        payload = (pickle.dumps(env), pickle.dumps(policy), scope)
        singleton_pool.run_each(_worker_populate_task, [payload] * n_workers)
    logger.log("Populated")
示例#2
0
def populate_task(env, policy, ma_mode, scope=None):
    """
    Install ``env`` and the policy object(s) on the sampler worker(s).

    With more than one parallel worker, ``env`` and ``policy`` are pickled once
    and dispatched (together with ``ma_mode`` and ``scope``) to every worker via
    ``run_each``. Otherwise they are stored directly on the in-process scoped
    globals. In that case, ``ma_mode == 'concurrent'`` stores ``policy`` as
    ``policies`` and any other mode stores it as ``policy``.

    :param env: environment object to install
    :param policy: policy object, or (for ``'concurrent'`` mode) the policies
        collection
    :param ma_mode: mode string (presumably the multi-agent mode), logged and
        forwarded to the workers
    :param scope: forwarded to ``_get_scoped_G`` or to the worker task
    """
    logger.log("Populating workers...")
    logger.log("ma_mode={}".format(ma_mode))
    n_workers = singleton_pool.n_parallel
    if n_workers <= 1:
        # Single process: avoid copying by assigning the objects directly.
        shared = _get_scoped_G(singleton_pool.G, scope)
        shared.env = env
        target_attr = 'policies' if ma_mode == 'concurrent' else 'policy'
        setattr(shared, target_attr, policy)
    else:
        # Serialize once; every worker receives an identical payload.
        payload = (pickle.dumps(env), pickle.dumps(policy), ma_mode, scope)
        singleton_pool.run_each(_worker_populate_task, [payload] * n_workers)
    logger.log("Populated")
示例#3
0
def sample_paths(policy_params,
                 max_samples,
                 ma_mode,
                 max_path_length=np.inf,
                 env_params=None,
                 scope=None):
    """
    Push parameters to every worker, then collect sampled paths.

    :param policy_params: policy parameters sent to each worker process. Must be
        a list when ``ma_mode == 'concurrent'`` (presumably one entry per
        policy; confirm against ``_worker_set_policy_params``).
    :param max_samples: desired maximum number of samples, passed as
        ``threshold`` to ``singleton_pool.run_collect``
    :param ma_mode: mode string forwarded to every worker task
    :param max_path_length: maximum length of a single trajectory, forwarded to
        the collection task (default: unbounded)
    :param env_params: optional environment parameters; sent to every worker
        only when not None
    :param scope: forwarded to every worker task
    :raises TypeError: if ``ma_mode == 'concurrent'`` and ``policy_params`` is
        not a list
    :return: the result of ``singleton_pool.run_collect`` (the collected paths)
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with ``-O``.
    if ma_mode == 'concurrent' and not isinstance(policy_params, list):
        raise TypeError(
            "policy_params must be a list when ma_mode == 'concurrent', "
            "got {}".format(type(policy_params).__name__))
    singleton_pool.run_each(_worker_set_policy_params,
                            [(policy_params, ma_mode, scope)] *
                            singleton_pool.n_parallel)
    if env_params is not None:
        singleton_pool.run_each(_worker_set_env_params, [(env_params, scope)] *
                                singleton_pool.n_parallel)

    return singleton_pool.run_collect(_worker_collect_path_one_env,
                                      threshold=max_samples,
                                      args=(max_path_length, ma_mode, scope),
                                      show_prog_bar=True)
示例#4
0
def sample_paths(policy_params,
                 max_samples,
                 max_path_length=np.inf,
                 env_params=None,
                 scope=None):
    """
    Broadcast the current parameters to every worker and gather rollouts.

    :param policy_params: parameters for the policy; sent to every worker process
    :param max_samples: desired maximum number of samples to be collected. The
        actual number may be greater, since every trajectory is rolled out until
        termination or until ``max_path_length`` is reached.
    :param max_path_length: horizon / maximum length of a single trajectory
        (default: unbounded)
    :param env_params: optional environment parameters; sent to every worker
        only when not None
    :param scope: forwarded to every worker task
    :return: a list of collected paths
    """
    policy_args = [(policy_params, scope)] * singleton_pool.n_parallel
    singleton_pool.run_each(_worker_set_policy_params, policy_args)
    if env_params is not None:
        env_args = [(env_params, scope)] * singleton_pool.n_parallel
        singleton_pool.run_each(_worker_set_env_params, env_args)
    return singleton_pool.run_collect(
        _worker_collect_one_path,
        threshold=max_samples,
        args=(max_path_length, scope),
        show_prog_bar=True,
    )
示例#5
0
def set_seed(seed):
    """
    Dispatch ``_worker_set_seed`` with a distinct seed per worker slot.

    Entry ``i`` of the argument list handed to ``run_each`` is ``(seed + i,)``,
    for ``i`` in ``range(singleton_pool.n_parallel)``.

    :param seed: base seed; offsets 0..n_parallel-1 are added to it
    """
    per_worker = []
    for offset in range(singleton_pool.n_parallel):
        per_worker.append((seed + offset,))
    singleton_pool.run_each(_worker_set_seed, per_worker)
示例#6
0
def terminate_task(scope=None):
    """
    Dispatch ``_worker_terminate_task`` to the pool.

    One ``(scope,)`` argument tuple is passed per parallel worker.

    :param scope: forwarded unchanged to every worker task
    """
    worker_args = [(scope,)] * singleton_pool.n_parallel
    singleton_pool.run_each(_worker_terminate_task, worker_args)
示例#7
0
def initialize(n_parallel):
    """
    Initialize the shared worker pool, then run ``_worker_init`` on it.

    Entry ``i`` of the argument list given to ``run_each`` is ``(i,)``, for
    ``i`` in ``range(singleton_pool.n_parallel)``.

    :param n_parallel: requested number of parallel workers, passed to
        ``singleton_pool.initialize``
    """
    singleton_pool.initialize(n_parallel)
    # The count is read back from the pool rather than from the argument.
    # NOTE(review): presumably the pool may normalise the requested count;
    # confirm in singleton_pool.initialize.
    # ``worker_id`` (not ``id``) avoids shadowing the builtin ``id``.
    singleton_pool.run_each(_worker_init,
                            [(worker_id, )
                             for worker_id in range(singleton_pool.n_parallel)])