コード例 #1
0
def _worker_set_policy_params(G, params, ma_mode, scope=None):
    """Load new parameter values into the policy (or policies) held by a worker.

    Args:
        G: worker-global container; the scoped view of it is obtained via
            ``_get_scoped_G(G, scope)``.
        params: in ``'concurrent'`` mode, indexed by policy position
            (``params[i]`` is handed to the i-th entry of ``policies``);
            in every other mode, passed whole to the single ``policy``.
        ma_mode: multi-agent mode string; only ``'concurrent'`` is treated
            specially here.
        scope: optional scope forwarded to ``_get_scoped_G``.
    """
    state = _get_scoped_G(G, scope)

    # Every non-concurrent mode has exactly one policy to update.
    if ma_mode != 'concurrent':
        state.policy.set_param_values(params)
        return

    # Concurrent mode: one parameter entry per policy, matched by index.
    for idx, agent_policy in enumerate(state.policies):
        agent_policy.set_param_values(params[idx])
コード例 #2
0
def _worker_populate_task(G, env, policy, ma_mode, scope=None):
    """Unpickle an environment and policy/policies into the worker's scoped state.

    Args:
        G: worker-global container; the scoped view of it is obtained via
            ``_get_scoped_G(G, scope)``.
        env: pickled environment (input to ``pickle.loads``).
        policy: pickled policy; in ``'concurrent'`` mode the unpickled value
            must be a ``list`` and is stored as ``policies``.
        ma_mode: multi-agent mode string; only ``'concurrent'`` is treated
            specially here.
        scope: optional scope forwarded to ``_get_scoped_G``.

    Raises:
        AssertionError: in ``'concurrent'`` mode when the unpickled policy
            object is not a list (only when assertions are enabled).
    """
    # TODO: better term for both policy/policies
    # NOTE(review): pickle.loads can execute arbitrary code; the payloads
    # presumably originate from the parent process only -- confirm.
    state = _get_scoped_G(G, scope)
    state.env = pickle.loads(env)
    loaded = pickle.loads(policy)

    if ma_mode != 'concurrent':
        state.policy = loaded
        return

    # The value is stored before the type check, so a failed check still
    # leaves it on the state object.
    state.policies = loaded
    assert isinstance(state.policies, list)
コード例 #3
0
def _worker_collect_path_one_env(G, max_path_length, ma_mode, scope=None):
    """Run one rollout with the worker's env and policy and report its size.

    Args:
        G: worker-global container; the scoped view of it is obtained via
            ``_get_scoped_G(G, scope)``.
        max_path_length: forwarded unchanged to the rollout function.
        ma_mode: ``'centralized'``, ``'decentralized'`` or ``'concurrent'``.
        scope: optional scope forwarded to ``_get_scoped_G``.

    Returns:
        A ``(result, count)`` tuple whose shape depends on ``ma_mode``:

        * centralized -- a single path and the length of its ``'rewards'``.
        * decentralized -- a sequence of paths and the sum of their
          ``'rewards'`` lengths.
        * concurrent -- a sequence of paths and the ``'rewards'`` length of
          the first path only.

    Raises:
        NotImplementedError: if ``ma_mode`` is not one of the three modes.
    """
    state = _get_scoped_G(G, scope)

    if ma_mode == 'centralized':
        single = cent_rollout(state.env, state.policy, max_path_length)
        n_steps = len(single['rewards'])
        return single, n_steps

    if ma_mode == 'decentralized':
        rollouts = dec_rollout(state.env, state.policy, max_path_length)
        # Count is the total over all returned paths.
        return rollouts, sum(len(p['rewards']) for p in rollouts)

    if ma_mode == 'concurrent':
        rollouts = conc_rollout(state.env, state.policies, max_path_length)
        step_counts = [len(p['rewards']) for p in rollouts]
        # Only the first path's length is reported in this mode.
        return rollouts, step_counts[0]

    raise NotImplementedError("incorrect rollout type")
コード例 #4
0
def _worker_terminate_task(G, scope=None):
    """Shut down and clear whatever task resources the scoped state holds.

    For each of ``env``, ``policy``, ``policies`` and ``sess`` (in that
    order): if the attribute exists and is truthy, it is shut down
    (``terminate()`` for env/policy objects, ``close()`` for ``sess``) and
    the attribute is reset to ``None``. Missing or falsy attributes are
    left untouched.

    Args:
        G: worker-global container; the scoped view of it is obtained via
            ``_get_scoped_G(G, scope)``.
        scope: optional scope forwarded to ``_get_scoped_G``.
    """
    state = _get_scoped_G(G, scope)

    env = getattr(state, "env", None)
    if env:
        env.terminate()
        state.env = None

    single_policy = getattr(state, "policy", None)
    if single_policy:
        single_policy.terminate()
        state.policy = None

    policy_list = getattr(state, "policies", None)
    if policy_list:
        for member in policy_list:
            member.terminate()
        state.policies = None

    session = getattr(state, "sess", None)
    if session:
        session.close()
        state.sess = None
コード例 #5
0
def populate_task(env, policy, ma_mode, scope=None):
    """Install the environment and policy/policies on every sampler worker.

    When ``singleton_pool.n_parallel`` is greater than one, ``env`` and
    ``policy`` are pickled and ``_worker_populate_task`` is run on each
    worker with identical arguments. Otherwise the objects are stored
    directly on the pool's own scoped state without any copy.

    Args:
        env: environment object to distribute.
        policy: a single policy, or (in ``'concurrent'`` mode) the object
            stored as ``policies``.
        ma_mode: multi-agent mode string; logged, forwarded to the workers,
            and used locally to choose between ``policy`` and ``policies``.
        scope: optional scope forwarded to ``_get_scoped_G`` / the workers.
    """
    logger.log("Populating workers...")
    logger.log("ma_mode={}".format(ma_mode))

    n_workers = singleton_pool.n_parallel
    if n_workers > 1:
        # Serialize once; the same argument tuple is repeated per worker.
        worker_args = (pickle.dumps(env), pickle.dumps(policy), ma_mode, scope)
        singleton_pool.run_each(_worker_populate_task,
                                [worker_args] * n_workers)
    else:
        # avoid unnecessary copying
        local_state = _get_scoped_G(singleton_pool.G, scope)
        local_state.env = env
        target_attr = 'policies' if ma_mode == 'concurrent' else 'policy'
        setattr(local_state, target_attr, policy)

    logger.log("Populated")