def _register_all():
    """Register every RLlib trainer known to the registries with Tune.

    Every key of ``ALGORITHMS`` and ``CONTRIBUTED_ALGORITHMS``, plus three
    extra hard-coded keys, is registered under its own name with the class
    returned by ``get_agent_class``.

    Contributed algorithms are keyed with a ``contrib/`` prefix. For each of
    them the un-prefixed alias is also registered, bound to a stub trainer
    whose ``setup`` raises ``NameError`` telling the user to run the
    ``contrib/`` name instead.
    """
    from ray.rllib.agents.trainer import Trainer, with_common_config
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    all_keys = (
        list(ALGORITHMS.keys())
        + list(CONTRIBUTED_ALGORITHMS.keys())
        + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    )
    for key in all_keys:
        register_trainable(key, get_agent_class(key))

    def _see_contrib(name):
        """Returns dummy agent class warning algo is in contrib/."""

        class _SeeContrib(Trainer):
            _name = "SeeContrib"
            _default_config = with_common_config({})

            def setup(self, config):
                raise NameError(
                    "Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # also register the aliases minus contrib/ to give a good error message
    for key in list(CONTRIBUTED_ALGORITHMS.keys()):
        assert key.startswith("contrib/")
        alias = key.split("/", 1)[1]
        # Never let the error stub shadow a real built-in algorithm that
        # happens to share the un-prefixed name.
        if alias not in ALGORITHMS:
            register_trainable(alias, _see_contrib(alias))
def _register_if_needed(cls, run_object):
    """Resolve ``run_object`` to a trainable identifier, registering it if new.

    Strings are taken to be identifiers that are already registered. Lambdas
    are never registered (they may belong to variant generation) and are
    handed back untouched after a warning. The interface of ``run_object``
    is not inspected.

    Arguments:
        run_object (str|function|class): Trainable to run. A string is
            returned unchanged; a named function or a class is registered
            under its ``__name__``.

    Returns:
        A string representing the trainable identifier (or the lambda
        itself when one was passed).

    Raises:
        TuneError: If ``run_object`` is not a string, function or class.
    """
    if isinstance(run_object, six.string_types):
        return run_object

    if isinstance(run_object, types.FunctionType):
        if run_object.__name__ == "<lambda>":
            logger.warning(
                "Not auto-registering lambdas - resolving as variant.")
            return run_object
    elif not isinstance(run_object, type):
        raise TuneError("Improper 'run' - not string nor trainable.")

    # Named functions and classes share one registration path.
    trainable_id = run_object.__name__
    register_trainable(trainable_id, run_object)
    return trainable_id
def _register_all():
    """Register the hard-coded list of RLlib agents with Tune.

    Each key is resolved to a class through ``get_agent_class`` and then
    registered under that same key.
    """
    from ray.rllib.agent import get_agent_class

    agent_keys = (
        "PPO",
        "ES",
        "DQN",
        "APEX",
        "A3C",
        "BC",
        "PG",
        "DDPG",
        "DDPG2",
        "APEX_DDPG",
        "__fake",
        "__sigmoid_fake_data",
        "__parameter_tuning",
    )
    for agent_key in agent_keys:
        agent_cls = get_agent_class(agent_key)
        register_trainable(agent_key, agent_cls)
def register_if_needed(cls, run_object):
    """Resolve ``run_object`` to a trainable identifier, registering it if new.

    A string is assumed to be registered already. The interface of
    ``run_object`` is not inspected.

    Arguments:
        run_object (str|function|class): Trainable to run. A string is
            returned unchanged. A ``sample_from`` instance is returned
            unchanged after a warning. Any other class or callable is
            registered under its ``__name__``, or under ``"DEFAULT"`` when
            it has none.

    Returns:
        A string representing the trainable identifier (or the
        ``sample_from`` object itself when one was passed).

    Raises:
        TuneError: If ``run_object`` is neither a string nor callable.
    """
    if isinstance(run_object, six.string_types):
        return run_object

    if isinstance(run_object, sample_from):
        logger.warning("Not registering trainable. Resolving as variant.")
        return run_object

    if not (isinstance(run_object, type) or callable(run_object)):
        raise TuneError("Improper 'run' - not string nor trainable.")

    if hasattr(run_object, "__name__"):
        trainable_id = run_object.__name__
    else:
        trainable_id = "DEFAULT"
        logger.warning(
            "No name detected on trainable. Using {}.".format(trainable_id))

    register_trainable(trainable_id, run_object)
    return trainable_id
def register_if_needed(cls, run_object):
    """Resolve ``run_object`` to a trainable identifier, registering it if new.

    A string is assumed to be registered already. The interface of
    ``run_object`` is not inspected.

    Args:
        run_object (str|function|class): Trainable to run. A string is
            returned unchanged. A ``Domain`` instance is returned unchanged
            after a warning. Anything else is registered under the name
            produced by ``cls.get_trainable_name``.

    Returns:
        A string representing the trainable identifier (or the ``Domain``
        object itself when one was passed).

    Raises:
        TypeError, PicklingError: Re-raised from registration with extra
            troubleshooting hints appended to the message.
    """
    if isinstance(run_object, str):
        return run_object

    if isinstance(run_object, Domain):
        logger.warning("Not registering trainable. Resolving as variant.")
        return run_object

    trainable_name = cls.get_trainable_name(run_object)
    try:
        register_trainable(trainable_name, run_object)
    except (TypeError, PicklingError) as err:
        hints = (
            "Other options: "
            "\n-Try reproducing the issue by calling "
            "`pickle.dumps(trainable)`. "
            "\n-If the error is typing-related, try removing "
            "the type annotations and try again."
        )
        # Same exception type, enriched message, original context suppressed.
        raise type(err)(" ".join((str(err), hints))) from None
    return trainable_name
def _register_if_needed(cls, run_object):
    """Register a Trainable or function at runtime and return its identifier.

    A string is assumed to name something already registered. Lambdas are
    not registered, since they could be part of variant generation. The
    interface of ``run_object`` is not inspected.

    Arguments:
        run_object (str|function|class): Trainable to run. Strings are
            returned as-is; otherwise the object's ``__name__`` becomes the
            registered identifier.

    Returns:
        A string representing the trainable identifier (or the lambda
        itself when one was passed).

    Raises:
        TuneError: If ``run_object`` is not a string, function or class.
    """
    if isinstance(run_object, six.string_types):
        return run_object

    is_function = isinstance(run_object, types.FunctionType)
    if not is_function and not isinstance(run_object, type):
        raise TuneError("Improper 'run' - not string nor trainable.")

    if is_function and run_object.__name__ == "<lambda>":
        logger.warning(
            "Not auto-registering lambdas - resolving as variant.")
        return run_object

    identifier = run_object.__name__
    register_trainable(identifier, run_object)
    return identifier
def _register_all():
    """Register the hard-coded list of RLlib agents with Tune, best effort.

    An ``ImportError`` raised while importing ``get_agent_class``, resolving
    a key or registering it is reported with a printed warning, and the loop
    moves on to the next key.
    """
    agent_keys = (
        "PPO",
        "ES",
        "DQN",
        "APEX",
        "A3C",
        "BC",
        "PG",
        "__fake",
        "__sigmoid_fake_data",
        "__parameter_tuning",
    )
    for agent_key in agent_keys:
        try:
            # Imported inside the try so a missing RLlib only yields warnings.
            from ray.rllib.agent import get_agent_class
            agent_cls = get_agent_class(agent_key)
            register_trainable(agent_key, agent_cls)
        except ImportError as err:
            print("Warning: could not import {}: {}".format(agent_key, err))
def _register_all():
    """Register all built-in and contributed RLlib trainers with Tune.

    The keys of ``ALGORITHMS`` and ``CONTRIBUTED_ALGORITHMS``, followed by
    three extra hard-coded keys, are each resolved through
    ``get_agent_class`` and registered under the same key.
    """
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    extra_keys = ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    trainer_keys = (
        list(ALGORITHMS.keys())
        + list(CONTRIBUTED_ALGORITHMS.keys())
        + extra_keys
    )
    for trainer_key in trainer_keys:
        register_trainable(trainer_key, get_agent_class(trainer_key))
def _register_all():
    """Populate the Tune registry with every RLlib trainer.

    Registration order is: keys of ``ALGORITHMS``, keys of
    ``CONTRIBUTED_ALGORITHMS``, then three extra hard-coded keys. The class
    for each key comes from ``get_agent_class``.
    """
    from ray.rllib.agents.registry import ALGORITHMS
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
    from ray.rllib.agents.registry import get_agent_class

    names = list(ALGORITHMS.keys())
    names += list(CONTRIBUTED_ALGORITHMS.keys())
    names += ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]

    for name in names:
        trainer_cls = get_agent_class(name)
        register_trainable(name, trainer_cls)
def _register_all():
    """Register the hard-coded list of agents with Tune, best effort.

    An ``ImportError`` raised while resolving or registering a key is
    reported with a printed warning and does not stop the remaining keys.
    """
    agent_keys = (
        "PPO",
        "ES",
        "DQN",
        "A3C",
        "BC",
        "__fake",
        "__sigmoid_fake_data",
    )
    for agent_key in agent_keys:
        try:
            agent_cls = get_agent_class(agent_key)
            register_trainable(agent_key, agent_cls)
        except ImportError as err:
            print("Warning: could not import {}: {}".format(agent_key, err))
def register_if_needed(cls, run_object):
    """Resolve ``run_object`` to a trainable identifier, registering it if new.

    A string is assumed to be registered already. The interface of
    ``run_object`` is not inspected.

    Arguments:
        run_object (str|function|class): Trainable to run. A string is
            returned unchanged. A ``Domain`` instance is returned unchanged
            after a warning. Any other class or callable is registered
            under a name derived from ``_name`` or ``__name__`` (see the
            inline comments), falling back to ``"DEFAULT"``.

    Returns:
        A string representing the trainable identifier (or the ``Domain``
        object itself when one was passed).

    Raises:
        TuneError: If ``run_object`` is neither a string nor callable.
        TypeError, PicklingError: Re-raised from registration with a
            diagnostic message replacing the original one.
    """
    if isinstance(run_object, str):
        return run_object

    if isinstance(run_object, Domain):
        logger.warning("Not registering trainable. Resolving as variant.")
        return run_object

    if not (isinstance(run_object, type) or callable(run_object)):
        raise TuneError("Improper 'run' - not string nor trainable.")

    trainable_id = "DEFAULT"
    if hasattr(run_object, "_name"):
        # An explicit `_name` attribute wins over `__name__`.
        trainable_id = run_object._name
    elif hasattr(run_object, "__name__"):
        raw_name = run_object.__name__
        if raw_name == "<lambda>":
            trainable_id = "lambda"
        elif not raw_name.startswith("<"):
            # Other "<...>" pseudo-names keep the "DEFAULT" fallback.
            trainable_id = raw_name
    else:
        logger.warning(
            "No name detected on trainable. Using {}.".format(trainable_id))

    try:
        register_trainable(trainable_id, run_object)
    except (TypeError, PicklingError) as err:
        diagnostic = (
            f"{str(err)}. The trainable ({str(run_object)}) could not "
            "be serialized, which is needed for parallel execution. "
            "To diagnose the issue, try the following:\n\n"
            "\t- Run `tune.utils.diagnose_serialization(trainable)` "
            "to check if non-serializable variables are captured "
            "in scope.\n"
            "\t- Try reproducing the issue by calling "
            "`pickle.dumps(trainable)`.\n"
            "\t- If the error is typing-related, try removing "
            "the type annotations and try again.\n\n"
            "If you have any suggestions on how to improve "
            "this error message, please reach out to the "
            "Ray developers on github.com/ray-project/ray/issues/")
        raise type(err)(diagnostic) from None
    return trainable_id
def testHasResourcesForTrialWithCaching(self):
    """Check ``has_resources_for_trial`` when a paused trial's PG is cached.

    Three trials are created: two request ``self.head_cpus`` CPUs and one
    requests ``self.head_cpus - 1``. The executor is created with
    ``reuse_actors=True`` and at most one pending trial. The assertions
    follow the executor's answers before staging, after staging, and after
    the first trial has been started and paused.
    """
    pgm = _PlacementGroupManager()
    pgf1 = PlacementGroupFactory([{"CPU": self.head_cpus}])
    pgf2 = PlacementGroupFactory([{"CPU": self.head_cpus - 1}])

    executor = RayTrialExecutor(reuse_actors=True)
    # Inject the manager created above so the test can inspect its state.
    executor._pg_manager = pgm
    executor.set_max_pending_trials(1)

    def train(config):
        yield 1
        yield 2
        yield 3
        yield 4

    register_trainable("resettable", train)

    trial1 = Trial("resettable", placement_group_factory=pgf1)
    trial2 = Trial("resettable", placement_group_factory=pgf1)
    trial3 = Trial("resettable", placement_group_factory=pgf2)

    # Before anything is staged, every trial is reported as schedulable.
    assert executor.has_resources_for_trial(trial1)
    assert executor.has_resources_for_trial(trial2)
    assert executor.has_resources_for_trial(trial3)

    executor._stage_and_update_status([trial1, trial2, trial3])

    # Poll until the manager reports a ready placement group for trial1.
    while not pgm.has_ready(trial1):
        time.sleep(1)
        executor._stage_and_update_status([trial1, trial2, trial3])

    # Fill staging
    executor._stage_and_update_status([trial1, trial2, trial3])

    assert executor.has_resources_for_trial(trial1)
    assert executor.has_resources_for_trial(trial2)
    assert not executor.has_resources_for_trial(trial3)

    executor._start_trial(trial1)
    executor._stage_and_update_status([trial1, trial2, trial3])
    executor.pause_trial(
        trial1)  # Caches the PG and removes a PG from staging

    assert len(pgm._staging_futures) == 0

    # This will re-schedule a placement group
    pgm.reconcile_placement_groups([trial1, trial2])

    assert len(pgm._staging_futures) == 1
    assert not pgm.can_stage()

    # We should still have resources for this trial as it has a cached PG
    assert executor.has_resources_for_trial(trial1)
    assert executor.has_resources_for_trial(trial2)
    assert not executor.has_resources_for_trial(trial3)
def _register_all():
    """Best-effort registration of the hard-coded RLlib agent list with Tune.

    Any ``ImportError`` raised while importing ``get_agent_class``,
    resolving a key or registering it is turned into a printed warning, so
    one unavailable agent does not block the others.
    """
    names = [
        "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG",
        "__fake", "__sigmoid_fake_data", "__parameter_tuning",
    ]
    for name in names:
        try:
            from ray.rllib.agent import get_agent_class
            resolved = get_agent_class(name)
            register_trainable(name, resolved)
        except ImportError as exc:
            print("Warning: could not import {}: {}".format(name, exc))
def load_algorithms(CUSTOM_ALGORITHMS):
    """Load the custom algorithms and register them with the Tune registry.

    Args:
        CUSTOM_ALGORITHMS: Mapping from algorithm name to a zero-argument
            callable. Each callable is invoked and its return value is
            registered under the corresponding name.
    """
    from ray.tune import registry

    for algorithm_name in CUSTOM_ALGORITHMS:
        loader = CUSTOM_ALGORITHMS[algorithm_name]
        trainable = loader()
        registry.register_trainable(algorithm_name, trainable)
def register_if_needed(cls, run_object):
    """Resolve ``run_object`` to a trainable identifier, registering it if new.

    A string is assumed to be registered already. The interface of
    ``run_object`` is not inspected.

    Arguments:
        run_object (str|function|class): Trainable to run. A string is
            returned unchanged. A ``Domain`` instance is returned unchanged
            after a warning. Any other class or callable is registered
            under a name derived from ``_name``, ``__name__`` or, for a
            ``partial``, the wrapped function's ``__name__``; the fallback
            is ``"DEFAULT"``.

    Returns:
        A string representing the trainable identifier (or the ``Domain``
        object itself when one was passed).

    Raises:
        TuneError: If ``run_object`` is neither a string nor callable.
        TypeError, PicklingError: Re-raised from registration with extra
            troubleshooting hints appended to the message.
    """
    if isinstance(run_object, str):
        return run_object

    if isinstance(run_object, Domain):
        logger.warning("Not registering trainable. Resolving as variant.")
        return run_object

    if not (isinstance(run_object, type) or callable(run_object)):
        raise TuneError("Improper 'run' - not string nor trainable.")

    trainable_id = "DEFAULT"
    if hasattr(run_object, "_name"):
        # An explicit `_name` attribute wins over `__name__`.
        trainable_id = run_object._name
    elif hasattr(run_object, "__name__"):
        raw_name = run_object.__name__
        if raw_name == "<lambda>":
            trainable_id = "lambda"
        elif not raw_name.startswith("<"):
            # Other "<...>" pseudo-names keep the "DEFAULT" fallback.
            trainable_id = raw_name
    elif (
        isinstance(run_object, partial)
        and hasattr(run_object, "func")
        and hasattr(run_object.func, "__name__")
    ):
        # Only reached when the partial itself has no `__name__`.
        trainable_id = run_object.func.__name__
    else:
        logger.warning(
            "No name detected on trainable. Using {}.".format(trainable_id))

    try:
        register_trainable(trainable_id, run_object)
    except (TypeError, PicklingError) as err:
        hints = (
            "Other options: "
            "\n-Try reproducing the issue by calling "
            "`pickle.dumps(trainable)`. "
            "\n-If the error is typing-related, try removing "
            "the type annotations and try again."
        )
        raise type(err)(" ".join((str(err), hints))) from None
    return trainable_id
def AdaptDLTrainableCreator(func: Callable,
                            num_workers: int = 1,
                            group: int = 0,
                            num_cpus_per_worker: int = 1,
                            num_workers_per_host: Optional[int] = None,
                            backend: str = "gloo",
                            timeout_s: int = NCCL_TIMEOUT_S,
                            use_gpu=None):
    """Trainable creator for AdaptDL's elastic Trials.

    Builds a ``_TorchTrainable`` subclass bound to ``func``, renames it
    after its replica count and group, registers it with Tune under that
    name and returns the class.

    Args:
        func: Training function stored on the class as ``_function``.
        num_workers: Number of workers, stored as ``_num_workers`` and
            encoded in the class name.
        group: Group index encoded in the class name.
        num_cpus_per_worker: Accepted but not referenced in this function.
        num_workers_per_host: Accepted but not referenced in this function.
        backend: Process-group backend. Overridden with ``"nccl"`` when
            ``config.default_device()`` reports ``"GPU"``.
        timeout_s: Process-group timeout, in seconds.
        use_gpu: Accepted but not referenced in this function; GPU use is
            decided by ``config.default_device()``.

    Returns:
        The registered ``AdaptDLTrainable`` class.
    """
    if config.default_device() == "GPU":
        backend = "nccl"

    class AdaptDLTrainable(_TorchTrainable):
        """ Similar to DistributedTrainable but for AdaptDLTrials."""

        def setup(self, config: Dict):
            """ Delay-patch methods when the Trainable actors are first
            created"""
            with patch(target="ray.tune.integration.torch.setup_process_group",
                       new=P.setup_process_group), \
                    patch(target='ray.tune.integration.torch.wrap_function',
                          new=P.wrap_function_patched):
                _TorchTrainable.setup(self, config)

        # Override the default resources and use custom PG factory
        @classmethod
        def default_resource_request(cls, config: Dict) -> Resources:
            return None

        def get_sched_hints(self):
            return ray.get(self.workers[0].get_sched_hints.remote())

        def save_all_states(self, trial_state):
            return ray.get(self.workers[0].save_all_states.remote(trial_state))

        @classmethod
        def default_process_group_parameters(cls) -> Dict:
            # timedelta's first positional argument is *days*; the timeout
            # is expressed in seconds, so it must be passed by keyword.
            return dict(timeout=timedelta(seconds=timeout_s), backend=backend)

    AdaptDLTrainable._function = func
    AdaptDLTrainable._num_workers = num_workers

    # Set number of GPUs if we're using them, this is later used when spawning
    # the trial actors
    if config.default_device() == "GPU":
        AdaptDLTrainable._num_gpus_per_worker = 1
    else:
        AdaptDLTrainable._num_gpus_per_worker = 0

    # Trainables are named after number of replicas they spawn. This is
    # essential to associate the right Trainable with the right Trial and PG.
    AdaptDLTrainable.__name__ = AdaptDLTrainable.__name__.split("_")[0] + \
        f"_{num_workers}" + f"_{group}"
    register_trainable(AdaptDLTrainable.__name__, AdaptDLTrainable)
    return AdaptDLTrainable
def main():
    """Register environments, model and custom trainers, then run.

    Steps, in order: register the Doom and DMLab environments, register the
    ``vizdoom_vision_model`` custom model, register PPO and APPO trainer
    variants that use the custom TF policies, parse the command line and
    hand over to ``run_experiment``.
    """
    register_doom_envs_rllib()
    register_dmlab_envs_rllib()

    ModelCatalog.register_custom_model(
        'vizdoom_vision_model', VizdoomVisionNetwork)

    ppo_trainer = PPOTrainer.with_updates(default_policy=CustomPPOTFPolicy)
    register_trainable('CUSTOM_PPO', ppo_trainer)

    appo_trainer = APPOTrainer.with_updates(
        default_policy=CustomAPPOTFPolicy,
        get_policy_class=lambda _: CustomAPPOTFPolicy,
    )
    register_trainable('CUSTOM_APPO', appo_trainer)

    parser = create_parser()
    args = parser.parse_args()
    run_experiment(args, parser)
def _register_all():
    """Register every RLlib algorithm known to the registries with Tune.

    All keys of ``ALGORITHMS`` and ``CONTRIBUTED_ALGORITHMS``, plus three
    extra hard-coded keys, are registered with the class returned by
    ``get_algorithm_class``. For contributed algorithms (keyed with a
    ``contrib/`` prefix) the un-prefixed alias is also registered, unless a
    built-in algorithm already uses that name, bound to a stub whose
    ``setup`` raises ``NameError`` pointing at the ``contrib/`` name.
    """
    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    algorithm_keys = list(ALGORITHMS.keys())
    algorithm_keys += list(CONTRIBUTED_ALGORITHMS.keys())
    algorithm_keys += ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    for algorithm_key in algorithm_keys:
        register_trainable(algorithm_key, get_algorithm_class(algorithm_key))

    def _see_contrib(name):
        """Returns dummy agent class warning algo is in contrib/."""

        class _SeeContrib(Algorithm):
            def setup(self, config):
                raise NameError(
                    "Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # Also register the aliases minus contrib/ to give a good error message.
    for contrib_key in list(CONTRIBUTED_ALGORITHMS.keys()):
        assert contrib_key.startswith("contrib/")
        alias = contrib_key.split("/", 1)[1]
        if alias in ALGORITHMS:
            # A real algorithm owns this name; leave it alone.
            continue
        register_trainable(alias, _see_contrib(alias))
def main(args):
    """Train MADDPG on the ``simple_spread`` environment through Tune.

    Initializes Ray, registers ``MADDPGTrainer`` under the name ``"MADDPG"``
    and the environment under ``"simple_spread"``, builds one policy spec
    per entry of ``env.agents`` and launches ``ray.tune.run``.

    Args:
        args: Parsed options. The attributes read here are ``adv_policy``,
            ``good_policy``, ``num_adversaries``, ``num_episodes``,
            ``checkpoint_freq``, ``local_dir``, ``num_envs_per_worker``,
            ``max_episode_len``, ``num_units``, ``n_step``, ``gamma``,
            ``lr``, ``train_batch_size``, ``sample_batch_size`` and
            ``num_workers``.
    """
    ray.init()
    register_trainable("MADDPG", MADDPGTrainer)

    # Create test environment.
    env = parallel_env(simple_spread_v2)()
    # Register env
    env_name = 'simple_spread'
    register_env(env_name, lambda _: parallel_env(simple_spread_v2)())

    def gen_policy(i):
        # The comprehension's own `i` shadows the parameter inside the list
        # only; the `use_local_critic[i]` lookup below uses the parameter.
        use_local_critic = [
            args.adv_policy == "ddpg" if i < args.num_adversaries else
            args.good_policy == "ddpg" for i, _ in enumerate(env.agents)
        ]
        return (
            None,
            env.observation_space,
            env.action_space,
            {
                "agent_id": i,
                "use_local_critic": use_local_critic[i],
                # Space dicts are hard-coded for exactly three ids (0, 1, 2).
                "obs_space_dict": dict(zip([0, 1, 2], [env.observation_space, env.observation_space, env.observation_space])),
                "act_space_dict": dict(zip([0, 1, 2], [env.action_space, env.action_space, env.action_space])),
            }
        )

    # One policy per agent, keyed by the agent identifier itself.
    policies = {agent: gen_policy(i) for i, agent in enumerate(env.agents)}

    ray.tune.run(
        "MADDPG",
        stop={
            "episodes_total": args.num_episodes,
        },
        checkpoint_at_end=True,
        checkpoint_freq=args.checkpoint_freq,
        local_dir=args.local_dir,
        # restore=<path to a checkpoint file>,  # disabled; was args.restore
        config={
            # === Log ===
            "log_level": "ERROR",
            'render_env': True,

            # === Environment ===
            'env': env_name,
            # "env_config": {'max_cycles': 25, 'num_agents': 3, 'local_ratio': 0.5},
            "num_envs_per_worker": args.num_envs_per_worker,
            "horizon": args.max_episode_len,

            # === Policy Config ===
            # --- Model ---
            "good_policy": args.good_policy,
            "adv_policy": args.adv_policy,
            "actor_hiddens": [args.num_units] * 2,
            "actor_hidden_activation": "relu",
            "critic_hiddens": [args.num_units] * 2,
            "critic_hidden_activation": "relu",
            "n_step": args.n_step,
            "gamma": args.gamma,

            # --- Exploration ---
            "tau": 0.01,

            # --- Replay buffer ---
            "buffer_size": int(1e6),

            # --- Optimization ---
            "actor_lr": args.lr,
            "critic_lr": args.lr,
            "learning_starts": args.train_batch_size * args.max_episode_len,
            "rollout_fragment_length": args.sample_batch_size,
            "train_batch_size": args.train_batch_size,
            "batch_mode": "truncate_episodes",

            # --- Parallelism ---
            "num_workers": args.num_workers,
            # GPU count comes from the RLLIB_NUM_GPUS environment variable.
            "num_gpus": int(os.environ.get('RLLIB_NUM_GPUS', '0')),
            # "num_gpus_per_worker": 0,

            # === Multi-agent setting ===
            "multiagent": {
                "policies": policies,
                # Identity mapping: policies are keyed by agent id above.
                "policy_mapping_fn": lambda agent_id: agent_id
            }}
    )
from __future__ import absolute_import from __future__ import division from __future__ import print_function import ray from ray.tune.registry import register_trainable, register_env from ray.tune import run_experiments, grid_search from ray.rllib.models.catalog import ModelCatalog from maml import MAMLAgent from point_env import PointEnv from reset_wrapper import ResetWrapper from fcnet import FullyConnectedNetwork register_trainable("MAML", MAMLAgent) env_cls = PointEnv register_env(env_cls.__name__, lambda env_config: ResetWrapper(env_cls(), env_config)) ModelCatalog.register_custom_model("maml_mlp", FullyConnectedNetwork) # ray.init() ray.init(redis_address="localhost:32222") config = { "random_seed": grid_search([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), "inner_lr": grid_search([0.01]), "inner_grad_clip": grid_search([10.0, 20.0, 30.0, 40.0]), "clip_param": grid_search([0.1, 0.2, 0.3]), "vf_loss_coeff": grid_search([0.01, 0.02, 0.05, 0.1, 0.2]), "vf_clip_param": grid_search([5.0, 10.0, 15.0, 20.0]),
def main(args):
    """Train MADDPG on a multi-agent particle environment via Tune.

    Initializes Ray with memory limits derived from the system memory,
    registers a ``MADDPGTrainer`` variant with the ``CustomStdOut`` mixin
    under ``"MADDPG"`` and the environment under ``"mpe"``, builds one
    policy spec per agent and launches ``run_experiments``.

    Args:
        args: Parsed options. The attributes read here are ``num_gpus``,
            ``temp_dir``, ``scenario``, ``adv_policy``, ``good_policy``,
            ``num_adversaries``, ``add_postfix``, ``num_episodes``,
            ``checkpoint_freq``, ``local_dir``, ``restore``,
            ``num_envs_per_worker``, ``max_episode_len``, ``num_units``,
            ``n_step``, ``gamma``, ``tau``, ``replay_buffer``, ``lr``,
            ``train_batch_size``, ``sample_batch_size`` and ``num_workers``.
    """
    # Earlier, disabled ray.init variants kept for reference:
    # ray.init(redis_max_memory=int(1e10), object_store_memory=int(3e9))
    # memory=int(6200000000)
    # ray.init(memory=int(4200000000), object_store_memory=int(2200000000), num_gpus=1, num_cpus=6)
    # ray.init(redis_max_memory=int(ray.utils.get_system_memory() * 0.4),
    #          memory=int(ray.utils.get_system_memory() * 0.2),
    #          object_store_memory=int(ray.utils.get_system_memory() * 0.2),
    #          huge_pages=False, num_gpus=1, num_cpus=6,
    #          temp_dir='/mnt/hdd-a500/Ray_temp/')
    ray.init(
        redis_max_memory=int(ray.utils.get_system_memory() * 0.4),
        memory=int(ray.utils.get_system_memory() * 0.2),
        object_store_memory=int(ray.utils.get_system_memory() * 0.2),
        # huge_pages=False,
        num_gpus=args.num_gpus,
        num_cpus=6,
        temp_dir=args.temp_dir)

    MADDPGAgent = MADDPGTrainer.with_updates(mixins=[CustomStdOut])
    # NOTE(review): this registers under "MADDPG", but the experiment below
    # runs "contrib/MADDPG" -- confirm which trainable is meant to be used.
    register_trainable("MADDPG", MADDPGAgent)

    def env_creater(mpe_args):
        return MultiAgentParticleEnv(**mpe_args)

    register_env("mpe", env_creater)

    # Local environment instance, used only to read spaces and agent count.
    env = env_creater({
        "scenario_name": args.scenario,
    })

    def gen_policy(i):
        # The comprehension's own `i` shadows the parameter inside the list
        # only; the indexing below uses the parameter.
        use_local_critic = [
            args.adv_policy == "ddpg" if i < args.num_adversaries else
            args.good_policy == "ddpg" for i in range(env.num_agents)
        ]
        return (None, env.observation_space_dict[i], env.action_space_dict[i],
                {
                    "agent_id": i,
                    "use_local_critic": use_local_critic[i],
                    "obs_space_dict": env.observation_space_dict,
                    "act_space_dict": env.action_space_dict,
                })

    policies = {
        "policy_%d" % i: gen_policy(i)
        for i in range(len(env.observation_space_dict))
    }
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id):
        # Uses agent_id as an index into the ordered list of policy names.
        return policy_ids[agent_id]

    # Experiment name: scenario with "_" and "-" removed, plus an optional
    # "_<postfix>" suffix.
    exp_name = "{}{}".format(
        args.scenario.replace("_", "").replace("-", ""),
        "_{}".format(args.add_postfix) if args.add_postfix != "" else "")

    run_experiments(
        {
            exp_name: {
                "run": "contrib/MADDPG",
                "env": "mpe",
                "stop": {
                    "episodes_total": args.num_episodes,
                },
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                "restore": args.restore,
                "config": {
                    # === Log ===
                    "log_level": "ERROR",

                    # === Environment ===
                    "env_config": {
                        "scenario_name": args.scenario,
                    },
                    "num_envs_per_worker": args.num_envs_per_worker,
                    "horizon": args.max_episode_len,

                    # === Policy Config ===
                    # --- Model ---
                    "good_policy": args.good_policy,
                    "adv_policy": args.adv_policy,
                    "actor_hiddens": [args.num_units] * 2,
                    "actor_hidden_activation": "relu",
                    "critic_hiddens": [args.num_units] * 2,
                    "critic_hidden_activation": "relu",
                    "n_step": args.n_step,
                    "gamma": args.gamma,

                    # --- Exploration ---
                    "tau": args.tau,

                    # --- Replay buffer ---
                    "buffer_size": args.replay_buffer,  # int(10000), # int(1e6)

                    # --- Optimization ---
                    "actor_lr": args.lr,
                    "critic_lr": args.lr,
                    "learning_starts": args.train_batch_size * args.max_episode_len,
                    "sample_batch_size": args.sample_batch_size,
                    "train_batch_size": args.train_batch_size,
                    "batch_mode": "truncate_episodes",

                    # --- Parallelism ---
                    "num_workers": args.num_workers,
                    "num_gpus": args.num_gpus,
                    "num_gpus_per_worker": 0,

                    # === Multi-agent setting ===
                    "multiagent": {
                        "policies": policies,
                        "policy_mapping_fn": ray.tune.function(policy_mapping_fn)
                    },
                },
            },
        },
        verbose=0,
        reuse_actors=False)  # reuse_actors=True - messes up the results
_policy_graph = filter_var_policy_factory(scope_not_to_freeze='film') class VisionFrozenApexLoadedWeight(dqn.ApexAgent): _agent_name = "APEX_VISION_FROZEN_LOADED_WEIGHT" _default_config = dqn.apex.APEX_DEFAULT_CONFIG _policy_graph = loading_class_factory( filter_var_policy_factory(scope_not_to_freeze='film'), 'test_weight.npy') class FilmFrozenApexLoadedWeight(dqn.ApexAgent): _agent_name = "APEX_FILM_FROZEN_LOADED_WEIGHT" _default_config = dqn.apex.APEX_DEFAULT_CONFIG _policy_graph = loading_class_factory( filter_var_policy_factory(scope_to_freeze='film'), 'test_weight.npy') # Frozen GraphModel register_trainable("apex_vision_frozen", VisionFrozenApex) register_trainable("apex_film_frozen", FilmFrozenApex) # Model loading register_trainable("apex_vision_frozen_loaded_weight", VisionFrozenApexLoadedWeight) register_trainable("apex_film_frozen_loaded_weight", FilmFrozenApexLoadedWeight)
def _register_all():
    """Register the built-in agents and the two test trainables with Tune."""
    builtin_trainables = (
        ("PPO", ppo.PPOAgent),
        ("ES", es.ESAgent),
        ("DQN", dqn.DQNAgent),
        ("A3C", a3c.A3CAgent),
        ("__fake", _MockAgent),
        ("__sigmoid_fake_data", _SigmoidFakeData),
    )
    for trainable_name, trainable_cls in builtin_trainables:
        register_trainable(trainable_name, trainable_cls)
if alg_name == "maml": agent_cls = MAMLAgent elif alg_name == "meta-sgd": agent_cls = MetaSGDAgent elif alg_name == "maesn": agent_cls = MAESNAgent elif alg_name == "tesp": agent_cls = TESPAgent all_agent_cls = [agent_cls] all_model_cls = [RLlibMLP, RLlibMAESN, RLlibTESP, RLlibTESPWithAdapPolicy] register_env(env_cls.__name__, lambda env_config: ResetWrapper(env_cls(env_config), env_config)) for agent_cls in all_agent_cls: register_trainable(agent_cls._agent_name, agent_cls) for model_cls in all_model_cls: ModelCatalog.register_custom_model(model_cls.__name__, model_cls) def get_config(model_cls_name): config = { "random_seed": grid_search([1]), "inner_lr": grid_search([0.001]), "inner_lr_bound": 0.1, "inner_grad_clip": 60.0, "num_inner_updates": 3, "outer_lr": grid_search([3e-4]), "num_sgd_iter": 20, "clip_param": 0.15, "model_loss_coeff": grid_search([0.01]),
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.tune.error import TuneError
from ray.tune.tune import run_experiments
from ray.tune.registry import register_env, register_trainable
from ray.tune.result import TrainingResult
from ray.tune.script_runner import ScriptRunner
from ray.tune.trainable import Trainable
from ray.tune.variant_generator import grid_search

# Import-time side effect: ScriptRunner is registered under the name
# "script" as soon as this module is imported.
register_trainable("script", ScriptRunner)

# Names exported by a star-import of this module.
__all__ = [
    "Trainable",
    "TrainingResult",
    "TuneError",
    "grid_search",
    "register_env",
    "register_trainable",
    "run_experiments",
]
def diagnose_serialization(trainable):
    """Utility for detecting why your trainable function isn't serializing.

    First tries to register ``trainable``. If that fails, the global and
    nonlocal variables reported by ``inspect.getclosurevars`` are checked
    one by one with ``check_serializability`` and the result of every check
    is printed.

    Args:
        trainable (func): The trainable object passed to
            tune.run(trainable). Currently only supports
            Function API.

    Returns:
        ``True`` if registration succeeded; otherwise the set of names of
        the captured variables that failed the check (possibly empty).

    Example:

    .. code-block:: python

        import threading
        # this is not serializable
        e = threading.Event()

        def test():
            print(e)

        diagnose_serialization(test)
        # should help identify that 'e' should be moved into
        # the `test` scope.

        # correct implementation
        def test():
            e = threading.Event()
            print(e)

        assert diagnose_serialization(test) is True

    """
    from ray.tune.registry import register_trainable, check_serializability

    def check_variables(objects, failure_set, printer):
        """Check each variable, report its status, collect failing names."""
        for var_name, variable in objects.items():
            msg = None
            try:
                check_serializability(var_name, variable)
                status = "PASSED"
            except Exception as e:
                status = "FAILED"
                msg = f"{e.__class__.__name__}: {str(e)}"
                failure_set.add(var_name)
            # The closing quote after the name used to be doubled.
            printer(f"{str(variable)}[name='{var_name}']... {status}")
            if msg:
                printer(msg)

    print(f"Trying to serialize {trainable}...")
    try:
        register_trainable("__test:" + str(trainable), trainable, warn=False)
        print("Serialization succeeded!")
        return True
    except Exception as e:
        # Deliberately broad: any failure here triggers the diagnosis below.
        print(f"Serialization failed: {e}")

    print("Inspecting the scope of the trainable by running "
          f"`inspect.getclosurevars({str(trainable)})`...")
    closure = inspect.getclosurevars(trainable)
    failure_set = set()
    if closure.globals:
        print(f"Detected {len(closure.globals)} global variables. "
              "Checking serializability...")
        check_variables(closure.globals, failure_set,
                        lambda s: print("   " + s))

    if closure.nonlocals:
        print(f"Detected {len(closure.nonlocals)} nonlocal variables. "
              "Checking serializability...")
        check_variables(closure.nonlocals, failure_set,
                        lambda s: print("   " + s))

    if not failure_set:
        print("Nothing was found to have failed the diagnostic test, though "
              "serialization did not succeed. Feel free to raise an "
              "issue on github.")
    else:
        print(f"Variable(s) {failure_set} was found to be non-serializable. "
              "Consider either removing the instantiation/imports "
              "of these objects or moving them into the scope of "
              "the trainable. ")
    return failure_set
def main(args):
    """Train MADDPG on a multi-agent particle environment via Tune.

    Initializes Ray with fixed memory limits, registers a ``MADDPGTrainer``
    variant with the ``CustomStdOut`` mixin under ``"MADDPG"`` and the
    environment under ``"mpe"``, builds one policy spec per agent and
    launches ``run_experiments`` for the ``"MADDPG_RLLib"`` experiment.

    Args:
        args: Parsed options. The attributes read here are ``scenario``,
            ``adv_policy``, ``good_policy``, ``num_adversaries``,
            ``num_episodes``, ``checkpoint_freq``, ``local_dir``,
            ``restore``, ``num_envs_per_worker``, ``max_episode_len``,
            ``num_units``, ``n_step``, ``gamma``, ``replay_buffer``, ``lr``,
            ``train_batch_size``, ``sample_batch_size``, ``num_workers``
            and ``num_gpus``.
    """
    ray.init(redis_max_memory=int(1e10), object_store_memory=int(3e9))

    MADDPGAgent = maddpg.MADDPGTrainer.with_updates(mixins=[CustomStdOut])
    # NOTE(review): this registers under "MADDPG", but the experiment below
    # runs "contrib/MADDPG" -- confirm which trainable is meant to be used.
    register_trainable("MADDPG", MADDPGAgent)

    def env_creater(mpe_args):
        return MultiAgentParticleEnv(**mpe_args)

    register_env("mpe", env_creater)

    # Local environment instance, used only to read spaces and agent count.
    env = env_creater({
        "scenario_name": args.scenario,
    })

    def gen_policy(i):
        # The comprehension's own `i` shadows the parameter inside the list
        # only; the indexing below uses the parameter.
        use_local_critic = [
            args.adv_policy == "ddpg" if i < args.num_adversaries else
            args.good_policy == "ddpg" for i in range(env.num_agents)
        ]
        return (None, env.observation_space_dict[i], env.action_space_dict[i],
                {
                    "agent_id": i,
                    "use_local_critic": use_local_critic[i],
                    "obs_space_dict": env.observation_space_dict,
                    "act_space_dict": env.action_space_dict,
                })

    policies = {
        "policy_%d" % i: gen_policy(i)
        for i in range(len(env.observation_space_dict))
    }
    policy_ids = list(policies.keys())

    run_experiments(
        {
            "MADDPG_RLLib": {
                "run": "contrib/MADDPG",
                "env": "mpe",
                "stop": {
                    "episodes_total": args.num_episodes,
                },
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                "restore": args.restore,
                "config": {
                    # === Log ===
                    "log_level": "ERROR",

                    # === Environment ===
                    "env_config": {
                        "scenario_name": args.scenario,
                    },
                    "num_envs_per_worker": args.num_envs_per_worker,
                    "horizon": args.max_episode_len,

                    # === Policy Config ===
                    # --- Model ---
                    "good_policy": args.good_policy,
                    "adv_policy": args.adv_policy,
                    "actor_hiddens": [args.num_units] * 2,
                    "actor_hidden_activation": "relu",
                    "critic_hiddens": [args.num_units] * 2,
                    "critic_hidden_activation": "relu",
                    "n_step": args.n_step,
                    "gamma": args.gamma,

                    # --- Exploration ---
                    "tau": 0.01,

                    # --- Replay buffer ---
                    "buffer_size": args.replay_buffer,

                    # --- Optimization ---
                    "actor_lr": args.lr,
                    "critic_lr": args.lr,
                    "learning_starts": args.train_batch_size * args.max_episode_len,
                    "sample_batch_size": args.sample_batch_size,
                    "train_batch_size": args.train_batch_size,
                    "batch_mode": "truncate_episodes",

                    # --- Parallelism ---
                    "num_workers": args.num_workers,
                    "num_gpus": args.num_gpus,
                    "num_gpus_per_worker": 0,

                    # === Multi-agent setting ===
                    "multiagent": {
                        "policies": policies,
                        # Uses the agent id as an index into policy_ids.
                        "policy_mapping_fn": ray.tune.function(lambda i: policy_ids[i])
                    },
                },
            },
        },
        verbose=0)