def _worker_set_policy_params(G, params, ma_mode, scope=None):
    """Push new parameter values into this worker's policy or policies.

    In ``'concurrent'`` mode, each policy in ``G.policies`` receives the
    element of ``params`` at the same index. In every other mode, the single
    ``G.policy`` receives ``params`` as-is.

    Args:
        G: worker global state; resolved via ``_get_scoped_G``.
        params: parameter values. Indexable per policy in ``'concurrent'``
            mode, otherwise passed straight to ``set_param_values``.
        ma_mode: multi-agent mode string.
        scope: optional scope forwarded to ``_get_scoped_G``.
    """
    scoped = _get_scoped_G(G, scope)
    if ma_mode != 'concurrent':
        scoped.policy.set_param_values(params)
        return
    # One parameter entry per agent policy, matched by position.
    for idx, agent_policy in enumerate(scoped.policies):
        agent_policy.set_param_values(params[idx])
def _worker_populate_task(G, env, policy, ma_mode, scope=None):
    """Install an unpickled environment and policy (or policies) on a worker.

    Args:
        G: worker global state; resolved via ``_get_scoped_G``.
        env: pickled environment bytes; stored on ``G.env``.
        policy: pickled policy bytes. In ``'concurrent'`` mode these must
            unpickle to a list of policies, stored on ``G.policies``.
            Otherwise the result is stored on ``G.policy``.
        ma_mode: multi-agent mode string.
        scope: optional scope forwarded to ``_get_scoped_G``.

    Raises:
        TypeError: in ``'concurrent'`` mode, if ``policy`` does not unpickle
            to a list. ``G.policies`` is left untouched in that case.
    """
    # TODO: better term for both policy/policies
    G = _get_scoped_G(G, scope)
    # NOTE(review): pickle.loads executes arbitrary code on malicious input.
    # These bytes are expected to come from the parent process's own
    # pickle.dumps (e.g. populate_task); never pass untrusted data here.
    G.env = pickle.loads(env)
    if ma_mode == 'concurrent':
        policies = pickle.loads(policy)
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable this validation.
        if not isinstance(policies, list):
            raise TypeError(
                "concurrent ma_mode expects a list of policies, got {}".format(
                    type(policies).__name__))
        G.policies = policies
    else:
        G.policy = pickle.loads(policy)
def _worker_collect_path_one_env(G, max_path_length, ma_mode, scope=None):
    """Run one rollout on this worker and report how many samples it produced.

    Args:
        G: worker global state; resolved via ``_get_scoped_G``.
        max_path_length: forwarded unchanged to the rollout function.
        ma_mode: one of ``'centralized'``, ``'decentralized'`` or
            ``'concurrent'``.
        scope: optional scope forwarded to ``_get_scoped_G``.

    Returns:
        ``(paths, n_samples)``:

        * ``'centralized'``: a single path and its ``len(path['rewards'])``.
        * ``'decentralized'``: a list of paths and the sum of their
          ``rewards`` lengths.
        * ``'concurrent'``: a list of paths and the ``rewards`` length of the
          first path only.

    Raises:
        NotImplementedError: for any other ``ma_mode``.
    """
    scoped = _get_scoped_G(G, scope)
    if ma_mode == 'centralized':
        single = cent_rollout(scoped.env, scoped.policy, max_path_length)
        return single, len(single['rewards'])
    if ma_mode == 'decentralized':
        agent_paths = dec_rollout(scoped.env, scoped.policy, max_path_length)
        # Every agent's path contributes its own samples, so sum them.
        return agent_paths, sum(len(p['rewards']) for p in agent_paths)
    if ma_mode == 'concurrent':
        agent_paths = conc_rollout(scoped.env, scoped.policies, max_path_length)
        step_counts = [len(p['rewards']) for p in agent_paths]
        # Only the first agent's length is reported. Presumably all agents
        # step in lockstep here, so the lengths agree (TODO confirm).
        return agent_paths, step_counts[0]
    raise NotImplementedError("incorrect rollout type")
def _worker_terminate_task(G, scope=None):
    """Tear down everything installed on this worker's scoped global state.

    Calls ``terminate()`` on ``G.env``, ``G.policy`` and each entry of
    ``G.policies``, and ``close()`` on ``G.sess``. Each of those attributes
    is then reset to ``None``. Attributes that are missing or already
    ``None`` are skipped.

    Args:
        G: worker global state; resolved via ``_get_scoped_G``.
        scope: optional scope forwarded to ``_get_scoped_G``.
    """
    G = _get_scoped_G(G, scope)
    # Compare against None explicitly. A truthiness test would skip cleanup
    # for objects that are falsy but still hold resources (e.g. an empty
    # policies list, or anything defining __len__/__bool__).
    env = getattr(G, "env", None)
    if env is not None:
        env.terminate()
        G.env = None
    policy = getattr(G, "policy", None)
    if policy is not None:
        policy.terminate()
        G.policy = None
    policies = getattr(G, "policies", None)
    if policies is not None:
        for agent_policy in policies:
            agent_policy.terminate()
        G.policies = None
    sess = getattr(G, "sess", None)
    if sess is not None:
        sess.close()
        G.sess = None
def populate_task(env, policy, ma_mode, scope=None):
    """Make ``env`` and ``policy`` available to every sampler worker.

    With more than one parallel worker, ``env`` and ``policy`` are pickled
    once and the same payload is sent to each worker via
    ``singleton_pool.run_each(_worker_populate_task, ...)``. With a single
    worker, the objects are stored directly on the pool's scoped ``G``
    without copying.

    Args:
        env: environment object.
        policy: a policy, or (in ``'concurrent'`` mode) the collection of
            per-agent policies.
        ma_mode: multi-agent mode string. ``'concurrent'`` stores ``policy``
            on ``G.policies``; any other value stores it on ``G.policy``.
        scope: optional scope forwarded to ``_get_scoped_G``.
    """
    logger.log("Populating workers...")
    logger.log("ma_mode={}".format(ma_mode))
    n_workers = singleton_pool.n_parallel
    if n_workers <= 1:
        # Single process: share the live objects directly (avoid unnecessary
        # copying through pickle).
        local_G = _get_scoped_G(singleton_pool.G, scope)
        local_G.env = env
        if ma_mode == 'concurrent':
            local_G.policies = policy
        else:
            local_G.policy = policy
    else:
        payload = (pickle.dumps(env), pickle.dumps(policy), ma_mode, scope)
        singleton_pool.run_each(_worker_populate_task, [payload] * n_workers)
    logger.log("Populated")