Ejemplo n.º 1
0
def _register_all():
    """Register every RLlib trainer with Tune, plus contrib aliases.

    All keys of ``ALGORITHMS`` and ``CONTRIBUTED_ALGORITHMS``, together with
    three internal test trainables, are registered under their own name.

    For each contributed algorithm ``contrib/<name>`` the bare ``<name>`` is
    additionally registered as a stub trainer whose ``setup`` raises a
    ``NameError`` pointing the user at the ``contrib/`` name. The stub is
    only registered when ``<name>`` is not itself a core algorithm, so a
    real trainer registered in the first loop is never overwritten.
    """
    # Imported lazily so that importing this module does not import RLlib.
    from ray.rllib.agents.trainer import Trainer, with_common_config
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys(
    )) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
        register_trainable(key, get_agent_class(key))

    def _see_contrib(name):
        """Returns dummy agent class warning algo is in contrib/."""

        class _SeeContrib(Trainer):
            _name = "SeeContrib"
            _default_config = with_common_config({})

            def setup(self, config):
                raise NameError(
                    "Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # Also register the aliases minus contrib/ to give a good error message.
    for key in list(CONTRIBUTED_ALGORITHMS.keys()):
        assert key.startswith("contrib/")
        alias = key.split("/", 1)[1]
        # Do not clobber a core algorithm that shares the alias name.
        if alias not in ALGORITHMS:
            register_trainable(alias, _see_contrib(alias))
Ejemplo n.º 2
0
    def _register_if_needed(cls, run_object):
        """Register a function or class trainable on the fly.

        Strings are taken to be identifiers that are already registered.
        Lambdas are never registered, since they may belong to variant
        generation; they are handed back unchanged. The interface of
        ``run_object`` is not inspected.

        Arguments:
            run_object (str|function|class): Trainable to run. Strings are
                returned untouched; named functions and classes are
                registered under their ``__name__``.

        Returns:
            The trainable identifier string (or the lambda itself when a
            lambda was passed).

        Raises:
            TuneError: If ``run_object`` is not a string, function or class.
        """
        if isinstance(run_object, six.string_types):
            return run_object

        is_function = isinstance(run_object, types.FunctionType)
        if not is_function and not isinstance(run_object, type):
            raise TuneError("Improper 'run' - not string nor trainable.")

        if is_function and run_object.__name__ == "<lambda>":
            logger.warning(
                "Not auto-registering lambdas - resolving as variant.")
            return run_object

        # Named functions and classes share one registration path.
        trainable_id = run_object.__name__
        register_trainable(trainable_id, run_object)
        return trainable_id
Ejemplo n.º 3
0
def _register_all():
    """Register the built-in RLlib agents and test trainables with Tune."""
    # Imported lazily so that importing this module does not import RLlib.
    from ray.rllib.agent import get_agent_class

    agent_keys = (
        "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", "DDPG2",
        "APEX_DDPG", "__fake", "__sigmoid_fake_data", "__parameter_tuning",
    )
    for agent_key in agent_keys:
        register_trainable(agent_key, get_agent_class(agent_key))
Ejemplo n.º 4
0
    def register_if_needed(cls, run_object):
        """Register a trainable at runtime and return its identifier.

        A string is assumed to name an already-registered trainable. The
        interface of ``run_object`` is not inspected.

        Arguments:
            run_object (str|function|class): Trainable to run. Strings and
                ``sample_from`` objects are returned unchanged; any other
                callable is registered under its ``__name__`` (or
                ``"DEFAULT"`` when it has none).

        Returns:
            A string representing the trainable identifier, or the
            ``sample_from`` object itself when one was passed.

        Raises:
            TuneError: If ``run_object`` is neither a string nor callable.
        """
        if isinstance(run_object, six.string_types):
            return run_object

        if isinstance(run_object, sample_from):
            logger.warning("Not registering trainable. Resolving as variant.")
            return run_object

        if not (isinstance(run_object, type) or callable(run_object)):
            raise TuneError("Improper 'run' - not string nor trainable.")

        if hasattr(run_object, "__name__"):
            trainable_id = run_object.__name__
        else:
            # Anonymous callables fall back to a fixed identifier.
            trainable_id = "DEFAULT"
            logger.warning(
                "No name detected on trainable. Using {}.".format(
                    trainable_id))
        register_trainable(trainable_id, run_object)
        return trainable_id
Ejemplo n.º 5
0
    def register_if_needed(cls, run_object):
        """Register a trainable at runtime and return its identifier.

        A string is assumed to name an already-registered trainable. The
        interface of ``run_object`` is not inspected.

        Args:
            run_object (str|function|class): Trainable to run. Strings and
                ``Domain`` objects are returned unchanged; anything else is
                registered under the name given by
                ``cls.get_trainable_name``.

        Returns:
            A string representing the trainable identifier, or the
            ``Domain`` object itself when one was passed.

        Raises:
            TypeError, PicklingError: Re-raised from registration with
                troubleshooting hints appended to the message.
        """
        if isinstance(run_object, str):
            return run_object

        if isinstance(run_object, Domain):
            logger.warning("Not registering trainable. Resolving as variant.")
            return run_object

        trainable_id = cls.get_trainable_name(run_object)
        try:
            register_trainable(trainable_id, run_object)
        except (TypeError, PicklingError) as err:
            hints = (
                "Other options: "
                "\n-Try reproducing the issue by calling "
                "`pickle.dumps(trainable)`. "
                "\n-If the error is typing-related, try removing "
                "the type annotations and try again."
            )
            # Same exception type, augmented message, original chain hidden.
            raise type(err)(str(err) + " " + hints) from None
        else:
            return trainable_id
Ejemplo n.º 6
0
    def _register_if_needed(cls, run_object):
        """Register a function or class trainable on the fly.

        Strings are taken to be identifiers that are already registered.
        Lambdas are not registered because they could be part of variant
        generation. The interface of ``run_object`` is not inspected.

        Arguments:
            run_object (str|function|class): Trainable to run. A string is
                returned as-is; a named function or a class is registered
                under its ``__name__``.

        Returns:
            A string representing the trainable identifier (or the lambda
            itself when a lambda was passed).

        Raises:
            TuneError: If ``run_object`` is not a string, function or class.
        """
        if isinstance(run_object, six.string_types):
            # Already an identifier; nothing to register.
            return run_object

        if isinstance(run_object, types.FunctionType):
            if run_object.__name__ == "<lambda>":
                logger.warning(
                    "Not auto-registering lambdas - resolving as variant.")
                return run_object
        elif not isinstance(run_object, type):
            raise TuneError("Improper 'run' - not string nor trainable.")

        # Only named functions and classes reach this point.
        trainable_name = run_object.__name__
        register_trainable(trainable_name, run_object)
        return trainable_name
Ejemplo n.º 7
0
def _register_all():
    """Register the built-in RLlib agents with Tune on a best-effort basis.

    An ``ImportError`` for any single key is reported on stdout and does not
    stop the remaining keys from being registered.
    """
    agent_keys = (
        "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "__fake",
        "__sigmoid_fake_data", "__parameter_tuning",
    )
    for agent_key in agent_keys:
        try:
            # Kept inside the try so a missing RLlib only produces warnings.
            from ray.rllib.agent import get_agent_class
            register_trainable(agent_key, get_agent_class(agent_key))
        except ImportError as err:
            print("Warning: could not import {}: {}".format(agent_key, err))
Ejemplo n.º 8
0
def _register_all():
    """Register all core and contributed RLlib agents with Tune.

    Three internal test trainables are registered as well.
    """
    # Imported lazily so that importing this module does not import RLlib.
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    internal_keys = ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    all_keys = (
        list(ALGORITHMS) + list(CONTRIBUTED_ALGORITHMS) + internal_keys)
    for agent_key in all_keys:
        register_trainable(agent_key, get_agent_class(agent_key))
Ejemplo n.º 9
0
def _register_all():
    """Register core, contributed and internal test trainables with Tune."""
    # Lazy imports keep RLlib out of this module's import path.
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    keys_to_register = list(ALGORITHMS)
    keys_to_register += list(CONTRIBUTED_ALGORITHMS)
    keys_to_register += [
        "__fake", "__sigmoid_fake_data", "__parameter_tuning"
    ]
    for trainable_key in keys_to_register:
        register_trainable(trainable_key, get_agent_class(trainable_key))
Ejemplo n.º 10
0
def _register_all():
    """Register the built-in agents with Tune on a best-effort basis.

    An ``ImportError`` raised while resolving or registering a key is
    printed as a warning and the loop moves on to the next key.
    """
    agent_keys = (
        "PPO", "ES", "DQN", "A3C", "BC", "__fake", "__sigmoid_fake_data",
    )
    for agent_key in agent_keys:
        try:
            register_trainable(agent_key, get_agent_class(agent_key))
        except ImportError as err:
            print("Warning: could not import {}: {}".format(agent_key, err))
Ejemplo n.º 11
0
    def register_if_needed(cls, run_object):
        """Register a trainable at runtime and return its identifier.

        A string is assumed to name an already-registered trainable. The
        interface of ``run_object`` is not inspected.

        The identifier is chosen in this order: the object's ``_name``
        attribute; its ``__name__`` (``"lambda"`` for lambdas, ``"DEFAULT"``
        for any other name starting with ``<``); otherwise ``"DEFAULT"``
        with a warning.

        Arguments:
            run_object (str|function|class): Trainable to run. Strings and
                ``Domain`` objects are returned unchanged; any other
                callable is registered.

        Returns:
            A string representing the trainable identifier, or the
            ``Domain`` object itself when one was passed.

        Raises:
            TuneError: If ``run_object`` is neither a string nor callable.
            TypeError, PicklingError: Re-raised from registration with a
                diagnostic message about serialization.
        """
        if isinstance(run_object, str):
            return run_object

        if isinstance(run_object, Domain):
            logger.warning("Not registering trainable. Resolving as variant.")
            return run_object

        if not (isinstance(run_object, type) or callable(run_object)):
            raise TuneError("Improper 'run' - not string nor trainable.")

        trainable_id = "DEFAULT"
        if hasattr(run_object, "_name"):
            trainable_id = run_object._name
        elif hasattr(run_object, "__name__"):
            fn_name = run_object.__name__
            if fn_name == "<lambda>":
                trainable_id = "lambda"
            elif not fn_name.startswith("<"):
                # Other "<...>" names keep the "DEFAULT" identifier.
                trainable_id = fn_name
        else:
            logger.warning(
                "No name detected on trainable. Using {}.".format(
                    trainable_id))

        try:
            register_trainable(trainable_id, run_object)
        except (TypeError, PicklingError) as e:
            msg = (
                f"{str(e)}. The trainable ({str(run_object)}) could not "
                "be serialized, which is needed for parallel execution. "
                "To diagnose the issue, try the following:\n\n"
                "\t- Run `tune.utils.diagnose_serialization(trainable)` "
                "to check if non-serializable variables are captured "
                "in scope.\n"
                "\t- Try reproducing the issue by calling "
                "`pickle.dumps(trainable)`.\n"
                "\t- If the error is typing-related, try removing "
                "the type annotations and try again.\n\n"
                "If you have any suggestions on how to improve "
                "this error message, please reach out to the "
                "Ray developers on github.com/ray-project/ray/issues/")
            # Same exception type, new message, original chain hidden.
            raise type(e)(msg) from None
        return trainable_id
Ejemplo n.º 12
0
    def testHasResourcesForTrialWithCaching(self):
        """Check ``has_resources_for_trial`` around staging and PG caching.

        Three trials share one executor created with ``reuse_actors=True``
        and at most one pending trial. Two of them use the same placement
        group factory; the third uses a different one. The assertions track
        which trials report available resources before staging, after
        staging is full, and after ``trial1`` has been started and paused.
        """
        pgm = _PlacementGroupManager()
        # pgf1 requests `self.head_cpus` CPUs; pgf2 requests one fewer.
        pgf1 = PlacementGroupFactory([{"CPU": self.head_cpus}])
        pgf2 = PlacementGroupFactory([{"CPU": self.head_cpus - 1}])

        executor = RayTrialExecutor(reuse_actors=True)
        # Inject the manager so the test can inspect its internal state.
        executor._pg_manager = pgm
        executor.set_max_pending_trials(1)

        def train(config):
            yield 1
            yield 2
            yield 3
            yield 4

        register_trainable("resettable", train)

        trial1 = Trial("resettable", placement_group_factory=pgf1)
        trial2 = Trial("resettable", placement_group_factory=pgf1)
        trial3 = Trial("resettable", placement_group_factory=pgf2)

        # Nothing is staged yet, so every trial reports resources.
        assert executor.has_resources_for_trial(trial1)
        assert executor.has_resources_for_trial(trial2)
        assert executor.has_resources_for_trial(trial3)

        executor._stage_and_update_status([trial1, trial2, trial3])

        # Poll until the manager reports a ready placement group for trial1.
        while not pgm.has_ready(trial1):
            time.sleep(1)
            executor._stage_and_update_status([trial1, trial2, trial3])

        # Fill staging
        executor._stage_and_update_status([trial1, trial2, trial3])

        # trial1 and trial2 share pgf1; trial3 (pgf2) is now refused.
        assert executor.has_resources_for_trial(trial1)
        assert executor.has_resources_for_trial(trial2)
        assert not executor.has_resources_for_trial(trial3)

        executor._start_trial(trial1)
        executor._stage_and_update_status([trial1, trial2, trial3])
        executor.pause_trial(
            trial1)  # Caches the PG and removes a PG from staging

        assert len(pgm._staging_futures) == 0

        # This will re-schedule a placement group
        pgm.reconcile_placement_groups([trial1, trial2])

        assert len(pgm._staging_futures) == 1

        assert not pgm.can_stage()

        # We should still have resources for this trial as it has a cached PG
        assert executor.has_resources_for_trial(trial1)
        assert executor.has_resources_for_trial(trial2)
        assert not executor.has_resources_for_trial(trial3)
Ejemplo n.º 13
0
def _register_all():
    """Register the built-in RLlib agents with Tune, tolerating import errors.

    Each key is attempted independently; an ``ImportError`` is printed as a
    warning and the next key is tried.
    """
    agent_keys = (
        "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "__fake",
        "__sigmoid_fake_data", "__parameter_tuning",
    )
    for agent_key in agent_keys:
        try:
            # The import stays inside the try so that a missing RLlib is
            # reported per key instead of aborting the whole function.
            from ray.rllib.agent import get_agent_class
            register_trainable(agent_key, get_agent_class(agent_key))
        except ImportError as err:
            print("Warning: could not import {}: {}".format(agent_key, err))
Ejemplo n.º 14
0
def load_algorithms(CUSTOM_ALGORITHMS):
    """Register this repository's custom algorithms with the Tune registry.

    Args:
        CUSTOM_ALGORITHMS: Mapping from algorithm name to a zero-argument
            callable. Each callable is invoked once and its return value is
            registered as the trainable for that name.
    """
    from ray.tune import registry

    for algorithm_name, algorithm_factory in CUSTOM_ALGORITHMS.items():
        trainable = algorithm_factory()
        registry.register_trainable(algorithm_name, trainable)
Ejemplo n.º 15
0
    def register_if_needed(cls, run_object):
        """Register a trainable at runtime and return its identifier.

        A string is assumed to name an already-registered trainable. The
        interface of ``run_object`` is not inspected.

        The identifier is chosen in this order: the object's ``_name``
        attribute; its ``__name__`` (``"lambda"`` for lambdas, ``"DEFAULT"``
        for any other name starting with ``<``); the wrapped function's
        ``__name__`` for a ``partial``; otherwise ``"DEFAULT"`` with a
        warning.

        Arguments:
            run_object (str|function|class): Trainable to run. Strings and
                ``Domain`` objects are returned unchanged; any other
                callable is registered.

        Returns:
            A string representing the trainable identifier, or the
            ``Domain`` object itself when one was passed.

        Raises:
            TuneError: If ``run_object`` is neither a string nor callable.
            TypeError, PicklingError: Re-raised from registration with
                troubleshooting hints appended to the message.
        """
        if isinstance(run_object, str):
            return run_object

        if isinstance(run_object, Domain):
            logger.warning("Not registering trainable. Resolving as variant.")
            return run_object

        if not (isinstance(run_object, type) or callable(run_object)):
            raise TuneError("Improper 'run' - not string nor trainable.")

        trainable_id = "DEFAULT"
        if hasattr(run_object, "_name"):
            trainable_id = run_object._name
        elif hasattr(run_object, "__name__"):
            fn_name = run_object.__name__
            if fn_name == "<lambda>":
                trainable_id = "lambda"
            elif not fn_name.startswith("<"):
                # Other "<...>" names keep the "DEFAULT" identifier.
                trainable_id = fn_name
        elif (
            isinstance(run_object, partial)
            and hasattr(run_object, "func")
            and hasattr(run_object.func, "__name__")
        ):
            trainable_id = run_object.func.__name__
        else:
            logger.warning(
                "No name detected on trainable. Using {}.".format(trainable_id)
            )

        try:
            register_trainable(trainable_id, run_object)
        except (TypeError, PicklingError) as err:
            hints = (
                "Other options: "
                "\n-Try reproducing the issue by calling "
                "`pickle.dumps(trainable)`. "
                "\n-If the error is typing-related, try removing "
                "the type annotations and try again."
            )
            # Same exception type, augmented message, original chain hidden.
            raise type(err)(str(err) + " " + hints) from None
        return trainable_id
Ejemplo n.º 16
0
def AdaptDLTrainableCreator(func: Callable,
                            num_workers: int = 1,
                            group: int = 0,
                            num_cpus_per_worker: int = 1,
                            num_workers_per_host: Optional[int] = None,
                            backend: str = "gloo",
                            timeout_s: int = NCCL_TIMEOUT_S,
                            use_gpu=None):
    """Trainable creator for AdaptDL's elastic Trials.

    Builds a ``_TorchTrainable`` subclass bound to ``func``, names it after
    its worker count and group, registers it with Tune and returns it.

    Args:
        func: Training function stored on the class as ``_function``.
        num_workers: Number of workers, stored as ``_num_workers`` and
            encoded in the class name.
        group: Group index encoded in the class name.
        num_cpus_per_worker: Accepted for interface compatibility; not used
            in this function.
        num_workers_per_host: Accepted for interface compatibility; not used
            in this function.
        backend: Process-group backend. Overridden with ``"nccl"`` when
            ``config.default_device()`` reports ``"GPU"``.
        timeout_s: Process-group timeout in seconds.
        use_gpu: Accepted for interface compatibility; not used in this
            function.

    Returns:
        The generated and registered ``AdaptDLTrainable`` class.
    """
    if config.default_device() == "GPU":
        backend = "nccl"

    class AdaptDLTrainable(_TorchTrainable):
        """Similar to DistributedTrainable but for AdaptDLTrials."""

        def setup(self, config: Dict):
            """Delay-patch methods when the Trainable actors are first
            created."""
            with patch(target="ray.tune.integration.torch.setup_process_group",
                       new=P.setup_process_group), \
                 patch(target='ray.tune.integration.torch.wrap_function',
                       new=P.wrap_function_patched):
                _TorchTrainable.setup(self, config)

        # Override the default resources and use custom PG factory
        @classmethod
        def default_resource_request(cls, config: Dict) -> Resources:
            return None

        def get_sched_hints(self):
            """Fetch scheduling hints from the first worker."""
            return ray.get(self.workers[0].get_sched_hints.remote())

        def save_all_states(self, trial_state):
            """Ask the first worker to save all states for ``trial_state``."""
            return ray.get(self.workers[0].save_all_states.remote(trial_state))

        @classmethod
        def default_process_group_parameters(cls) -> Dict:
            # timedelta's first positional argument is *days*; the value is
            # in seconds, so it must be passed by keyword.
            return dict(timeout=timedelta(seconds=timeout_s), backend=backend)

    AdaptDLTrainable._function = func
    AdaptDLTrainable._num_workers = num_workers
    # Set number of GPUs if we're using them, this is later used when spawning
    # the trial actors
    if config.default_device() == "GPU":
        AdaptDLTrainable._num_gpus_per_worker = 1
    else:
        AdaptDLTrainable._num_gpus_per_worker = 0

    # Trainables are named after number of replicas they spawn. This is
    # essential to associate the right Trainable with the right Trial and PG.
    AdaptDLTrainable.__name__ = AdaptDLTrainable.__name__.split("_")[0] + \
        f"_{num_workers}" + f"_{group}"
    register_trainable(AdaptDLTrainable.__name__, AdaptDLTrainable)
    return AdaptDLTrainable
Ejemplo n.º 17
0
def main():
    """Register envs, the custom model and custom trainers, then run.

    Two trainers derived from ``PPOTrainer`` and ``APPOTrainer`` are
    registered as ``CUSTOM_PPO`` and ``CUSTOM_APPO`` before the command line
    is parsed and handed to ``run_experiment``.
    """
    register_doom_envs_rllib()
    register_dmlab_envs_rllib()

    ModelCatalog.register_custom_model(
        'vizdoom_vision_model', VizdoomVisionNetwork)

    custom_ppo_trainer = PPOTrainer.with_updates(
        default_policy=CustomPPOTFPolicy)
    register_trainable('CUSTOM_PPO', custom_ppo_trainer)

    custom_appo_trainer = APPOTrainer.with_updates(
        default_policy=CustomAPPOTFPolicy,
        get_policy_class=lambda _: CustomAPPOTFPolicy,
    )
    register_trainable('CUSTOM_APPO', custom_appo_trainer)

    cli_parser = create_parser()
    cli_args = cli_parser.parse_args()
    run_experiment(cli_args, cli_parser)
Ejemplo n.º 18
0
def _register_all():
    """Register every RLlib algorithm with Tune, plus contrib aliases.

    Core, contributed and three internal test trainables are registered
    under their own keys. For each ``contrib/<name>`` whose bare ``<name>``
    is not a core algorithm, a stub is registered under ``<name>`` that
    raises ``NameError`` from ``setup`` to redirect the user.
    """
    # Imported lazily so that importing this module does not import RLlib.
    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    internal_keys = ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    all_keys = (
        list(ALGORITHMS) + list(CONTRIBUTED_ALGORITHMS) + internal_keys)
    for algo_key in all_keys:
        register_trainable(algo_key, get_algorithm_class(algo_key))

    def _see_contrib(name):
        """Return a stub algorithm class that points at ``contrib/<name>``."""
        class _SeeContrib(Algorithm):
            def setup(self, config):
                raise NameError(
                    "Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # Register the bare aliases so a missing "contrib/" prefix gets a
    # helpful error instead of an unknown-trainable failure.
    for contrib_key in list(CONTRIBUTED_ALGORITHMS.keys()):
        assert contrib_key.startswith("contrib/")
        alias = contrib_key.split("/", 1)[1]
        if alias in ALGORITHMS:
            # A core algorithm already owns this name; leave it alone.
            continue
        register_trainable(alias, _see_contrib(alias))
Ejemplo n.º 19
0
def main(args):
    """Start Ray, register MADDPG and the env, and launch a Tune run.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``adv_policy``, ``good_policy``, ``num_adversaries``,
            ``num_episodes``, ``checkpoint_freq``, ``local_dir``,
            ``num_envs_per_worker``, ``max_episode_len``, ``num_units``,
            ``n_step``, ``gamma``, ``lr``, ``train_batch_size``,
            ``sample_batch_size`` and ``num_workers``.
    """
    ray.init()
    register_trainable("MADDPG", MADDPGTrainer)

    # Create test environment.
    env = parallel_env(simple_spread_v2)()
    # Register env
    env_name = 'simple_spread'
    register_env(env_name, lambda _: parallel_env(simple_spread_v2)())

    def gen_policy(i):
        """Build the policy spec tuple for the agent at index ``i``."""
        # The comprehension's own `i` shadows the argument only inside the
        # comprehension; the list holds one flag per entry of env.agents.
        use_local_critic = [
            args.adv_policy == "ddpg" if i < args.num_adversaries else
            args.good_policy == "ddpg" for i, _ in enumerate(env.agents)
        ]
        return (
            None,
            env.observation_space,
            env.action_space,
            {
                "agent_id": i,
                "use_local_critic": use_local_critic[i],
                # NOTE(review): keys 0-2 are hard-coded, which presumably
                # assumes exactly three agents - confirm.
                "obs_space_dict": dict(zip([0, 1, 2], [env.observation_space, env.observation_space, env.observation_space])),
                "act_space_dict": dict(zip([0, 1, 2], [env.action_space, env.action_space, env.action_space])),
            }
        )

    # One policy per entry of env.agents, keyed by that entry.
    policies = {agent: gen_policy(i) for i, agent in enumerate(env.agents)}

    ray.tune.run(
            "MADDPG",
            stop={
                "episodes_total": args.num_episodes,
            },
            checkpoint_at_end=True,
            checkpoint_freq=args.checkpoint_freq,
            local_dir=args.local_dir,
            # restore='/home/jiekaijia/PycharmProjects/pettingzoo_comunication/ray_results/MADDPG/MADDPG_mpe_2ff0a_00000_0_2021-06-07_14-23-38/checkpoint_000119/checkpoint-119',
            # args.restore,
            config={
                # === Log ===
                "log_level": "ERROR",
                'render_env': True,

                # === Environment ===
                'env': env_name,
                # "env_config": {'max_cycles': 25, 'num_agents': 3, 'local_ratio': 0.5},
                "num_envs_per_worker": args.num_envs_per_worker,
                "horizon": args.max_episode_len,

                # === Policy Config ===
                # --- Model ---
                "good_policy": args.good_policy,
                "adv_policy": args.adv_policy,
                "actor_hiddens": [args.num_units] * 2,
                "actor_hidden_activation": "relu",
                "critic_hiddens": [args.num_units] * 2,
                "critic_hidden_activation": "relu",
                "n_step": args.n_step,
                "gamma": args.gamma,

                # --- Exploration ---
                "tau": 0.01,

                # --- Replay buffer ---
                "buffer_size": int(1e6),

                # --- Optimization ---
                "actor_lr": args.lr,
                "critic_lr": args.lr,
                "learning_starts": args.train_batch_size * args.max_episode_len,
                "rollout_fragment_length": args.sample_batch_size,
                "train_batch_size": args.train_batch_size,
                "batch_mode": "truncate_episodes",

                # --- Parallelism ---
                "num_workers": args.num_workers,
                # GPU count is read from the RLLIB_NUM_GPUS env var
                # (defaults to 0).
                "num_gpus": int(os.environ.get('RLLIB_NUM_GPUS', '0')),
                # "num_gpus_per_worker": 0,

                # === Multi-agent setting ===
                "multiagent": {
                    "policies": policies,
                    # Policies are keyed by agent, so the mapping is the
                    # identity.
                    "policy_mapping_fn": lambda agent_id: agent_id
                }}
    )
Ejemplo n.º 20
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.tune.registry import register_trainable, register_env
from ray.tune import run_experiments, grid_search
from ray.rllib.models.catalog import ModelCatalog

from maml import MAMLAgent
from point_env import PointEnv
from reset_wrapper import ResetWrapper
from fcnet import FullyConnectedNetwork

# Register the MAML agent, the point environment and the custom model under
# the string names used by the experiment configuration.
register_trainable("MAML", MAMLAgent)
env_cls = PointEnv
# The env is registered under its class name; each created instance is
# wrapped in ResetWrapper together with the env config.
register_env(env_cls.__name__,
             lambda env_config: ResetWrapper(env_cls(), env_config))
ModelCatalog.register_custom_model("maml_mlp", FullyConnectedNetwork)

# Connect to an existing Ray cluster at this Redis address; the plain
# local `ray.init()` call below is left disabled.
# ray.init()
ray.init(redis_address="localhost:32222")

config = {
    "random_seed": grid_search([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
    "inner_lr": grid_search([0.01]),
    "inner_grad_clip": grid_search([10.0, 20.0, 30.0, 40.0]),
    "clip_param": grid_search([0.1, 0.2, 0.3]),
    "vf_loss_coeff": grid_search([0.01, 0.02, 0.05, 0.1, 0.2]),
    "vf_clip_param": grid_search([5.0, 10.0, 15.0, 20.0]),
Ejemplo n.º 21
0
def main(args):
    """Start Ray, register MADDPG and the "mpe" env, and run the experiment.

    Ray is initialised with memory limits derived from the system memory,
    one policy per agent is generated from a probe environment, and a single
    experiment is submitted through ``run_experiments``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``num_gpus``, ``temp_dir``, ``scenario``, ``adv_policy``,
            ``good_policy``, ``num_adversaries``, ``add_postfix``,
            ``num_episodes``, ``checkpoint_freq``, ``local_dir``,
            ``restore``, ``num_envs_per_worker``, ``max_episode_len``,
            ``num_units``, ``n_step``, ``gamma``, ``tau``,
            ``replay_buffer``, ``lr``, ``train_batch_size``,
            ``sample_batch_size`` and ``num_workers``.
    """
    ray.init(
        redis_max_memory=int(ray.utils.get_system_memory() * 0.4),
        memory=int(ray.utils.get_system_memory() * 0.2),
        object_store_memory=int(ray.utils.get_system_memory() * 0.2),
        # huge_pages=False,
        num_gpus=args.num_gpus,
        num_cpus=6,
        temp_dir=args.temp_dir)

    # NOTE(review): the customised trainer is registered as "MADDPG", but
    # the experiment below runs "contrib/MADDPG" - confirm which trainable
    # is meant to be used.
    MADDPGAgent = MADDPGTrainer.with_updates(mixins=[CustomStdOut])
    register_trainable("MADDPG", MADDPGAgent)

    def env_creater(mpe_args):
        """Create a MultiAgentParticleEnv from keyword arguments."""
        return MultiAgentParticleEnv(**mpe_args)

    register_env("mpe", env_creater)

    # Probe environment, used only to read the spaces and the agent count.
    env = env_creater({
        "scenario_name": args.scenario,
    })

    def gen_policy(i):
        """Build the policy spec tuple for the agent at index ``i``."""
        # The comprehension's own `i` shadows the argument only inside the
        # comprehension; the list holds one flag per agent.
        use_local_critic = [
            args.adv_policy == "ddpg"
            if i < args.num_adversaries else args.good_policy == "ddpg"
            for i in range(env.num_agents)
        ]
        return (None, env.observation_space_dict[i], env.action_space_dict[i],
                {
                    "agent_id": i,
                    "use_local_critic": use_local_critic[i],
                    "obs_space_dict": env.observation_space_dict,
                    "act_space_dict": env.action_space_dict,
                })

    policies = {
        "policy_%d" % i: gen_policy(i)
        for i in range(len(env.observation_space_dict))
    }
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id):
        """Map an agent id to a policy id by indexing ``policy_ids``."""
        return policy_ids[agent_id]

    # Experiment name: the scenario without "_" / "-", plus an optional
    # "_<postfix>".
    exp_name = "{}{}".format(
        args.scenario.replace("_", "").replace("-", ""),
        "_{}".format(args.add_postfix) if args.add_postfix != "" else "")
    run_experiments(
        {
            exp_name: {
                "run": "contrib/MADDPG",
                "env": "mpe",
                "stop": {
                    "episodes_total": args.num_episodes,
                },
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                "restore": args.restore,
                "config": {
                    # === Log ===
                    "log_level": "ERROR",

                    # === Environment ===
                    "env_config": {
                        "scenario_name": args.scenario,
                    },
                    "num_envs_per_worker": args.num_envs_per_worker,
                    "horizon": args.max_episode_len,

                    # === Policy Config ===
                    # --- Model ---
                    "good_policy": args.good_policy,
                    "adv_policy": args.adv_policy,
                    "actor_hiddens": [args.num_units] * 2,
                    "actor_hidden_activation": "relu",
                    "critic_hiddens": [args.num_units] * 2,
                    "critic_hidden_activation": "relu",
                    "n_step": args.n_step,
                    "gamma": args.gamma,

                    # --- Exploration ---
                    "tau": args.tau,

                    # --- Replay buffer ---
                    "buffer_size":
                    args.replay_buffer,  # int(10000), # int(1e6)

                    # --- Optimization ---
                    "actor_lr": args.lr,
                    "critic_lr": args.lr,
                    "learning_starts":
                    args.train_batch_size * args.max_episode_len,
                    "sample_batch_size": args.sample_batch_size,
                    "train_batch_size": args.train_batch_size,
                    "batch_mode": "truncate_episodes",

                    # --- Parallelism ---
                    "num_workers": args.num_workers,
                    "num_gpus": args.num_gpus,
                    "num_gpus_per_worker": 0,

                    # === Multi-agent setting ===
                    "multiagent": {
                        "policies": policies,
                        "policy_mapping_fn":
                        ray.tune.function(policy_mapping_fn)
                    },
                },
            },
        },
        verbose=0,
        reuse_actors=False)  # reuse_actors=True - messes up the results
Ejemplo n.º 22
0
    _policy_graph = filter_var_policy_factory(scope_not_to_freeze='film')


class VisionFrozenApexLoadedWeight(dqn.ApexAgent):
    """Apex agent using a filtered policy graph wrapped by a loader.

    The policy graph is ``filter_var_policy_factory`` called with
    ``scope_not_to_freeze='film'``, passed through ``loading_class_factory``
    together with ``'test_weight.npy'``.
    """
    _agent_name = "APEX_VISION_FROZEN_LOADED_WEIGHT"
    _default_config = dqn.apex.APEX_DEFAULT_CONFIG

    # NOTE(review): presumably freezes every variable outside the 'film'
    # scope and loads initial weights from 'test_weight.npy' - confirm
    # against the two factories.
    _policy_graph = loading_class_factory(
        filter_var_policy_factory(scope_not_to_freeze='film'),
        'test_weight.npy')


class FilmFrozenApexLoadedWeight(dqn.ApexAgent):
    """Apex agent using a filtered policy graph wrapped by a loader.

    The policy graph is ``filter_var_policy_factory`` called with
    ``scope_to_freeze='film'``, passed through ``loading_class_factory``
    together with ``'test_weight.npy'``.
    """
    _agent_name = "APEX_FILM_FROZEN_LOADED_WEIGHT"
    _default_config = dqn.apex.APEX_DEFAULT_CONFIG

    # NOTE(review): presumably freezes only the 'film' scope and loads
    # initial weights from 'test_weight.npy' - confirm against the two
    # factories.
    _policy_graph = loading_class_factory(
        filter_var_policy_factory(scope_to_freeze='film'), 'test_weight.npy')


# Register the four Apex variants with Tune under lower-case string names.

# Frozen GraphModel
register_trainable("apex_vision_frozen", VisionFrozenApex)
register_trainable("apex_film_frozen", FilmFrozenApex)

# Model loading
register_trainable("apex_vision_frozen_loaded_weight",
                   VisionFrozenApexLoadedWeight)
register_trainable("apex_film_frozen_loaded_weight",
                   FilmFrozenApexLoadedWeight)
Ejemplo n.º 23
0
def _register_all():
    """Register the built-in agents and the two mock trainables with Tune."""
    builtin_trainables = (
        ("PPO", ppo.PPOAgent),
        ("ES", es.ESAgent),
        ("DQN", dqn.DQNAgent),
        ("A3C", a3c.A3CAgent),
        ("__fake", _MockAgent),
        ("__sigmoid_fake_data", _SigmoidFakeData),
    )
    for trainable_id, trainable_cls in builtin_trainables:
        register_trainable(trainable_id, trainable_cls)
Ejemplo n.º 24
0
# Select the agent class that matches the requested algorithm name.
# NOTE(review): there is no `else` branch. For any other `alg_name`,
# `agent_cls` keeps whatever value it had before this chunk, or is undefined
# and the next statement raises NameError - confirm the intended fallback.
if alg_name == "maml":
    agent_cls = MAMLAgent
elif alg_name == "meta-sgd":
    agent_cls = MetaSGDAgent
elif alg_name == "maesn":
    agent_cls = MAESNAgent
elif alg_name == "tesp":
    agent_cls = TESPAgent

all_agent_cls = [agent_cls]
all_model_cls = [RLlibMLP, RLlibMAESN, RLlibTESP, RLlibTESPWithAdapPolicy]

# The env is registered under its class name; each created instance is
# wrapped in ResetWrapper together with the env config.
register_env(env_cls.__name__,
             lambda env_config: ResetWrapper(env_cls(env_config), env_config))
# Agents are registered under their `_agent_name`, models under their class
# name. The loop variable rebinds the module-level `agent_cls`.
for agent_cls in all_agent_cls:
    register_trainable(agent_cls._agent_name, agent_cls)
for model_cls in all_model_cls:
    ModelCatalog.register_custom_model(model_cls.__name__, model_cls)


def get_config(model_cls_name):
    config = {
        "random_seed": grid_search([1]),
        "inner_lr": grid_search([0.001]),
        "inner_lr_bound": 0.1,
        "inner_grad_clip": 60.0,
        "num_inner_updates": 3,
        "outer_lr": grid_search([3e-4]),
        "num_sgd_iter": 20,
        "clip_param": 0.15,
        "model_loss_coeff": grid_search([0.01]),
Ejemplo n.º 25
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.tune.error import TuneError
from ray.tune.tune import run_experiments
from ray.tune.registry import register_env, register_trainable
from ray.tune.result import TrainingResult
from ray.tune.script_runner import ScriptRunner
from ray.tune.trainable import Trainable
from ray.tune.variant_generator import grid_search

# Import-time side effect: make ScriptRunner available under the name
# "script" in the trainable registry.
register_trainable("script", ScriptRunner)

# Public names exported by a star-import of this module.
__all__ = [
    "Trainable",
    "TrainingResult",
    "TuneError",
    "grid_search",
    "register_env",
    "register_trainable",
    "run_experiments",
]
Ejemplo n.º 26
0
def diagnose_serialization(trainable):
    """Utility for detecting why your trainable function isn't serializing.

    The trainable is first registered under a throwaway name. If that
    fails, the global and nonlocal variables referenced by the function
    are checked one by one and the outcome of each check is printed.

    Args:
        trainable (func): The trainable object passed to
            tune.run(trainable). Currently only supports
            Function API.

    Returns:
        bool | set: ``True`` if registering the trainable succeeds.
        Otherwise the set of names of the referenced variables that
        failed the serializability check; the set is empty when no
        offending variable could be identified.

    Raises:
        TypeError: Propagated from ``inspect.getclosurevars`` when
            registration fails and ``trainable`` is not a Python function
            or method.

    Example:

    .. code-block:: python

        import threading
        # this is not serializable
        e = threading.Event()

        def test():
            print(e)

        diagnose_serialization(test)
        # should help identify that 'e' should be moved into
        # the `test` scope.

        # correct implementation
        def test():
            e = threading.Event()
            print(e)

        assert diagnose_serialization(test) is True

    """
    # Local import keeps the function self-contained.
    import inspect

    from ray.tune.registry import register_trainable, check_serializability

    def print_indented(text):
        """Print a per-variable report line indented under its header."""
        print("   " + text)

    def check_variables(objects, failure_set, printer):
        """Check every named object and record the names that fail."""
        for var_name, variable in objects.items():
            msg = None
            try:
                check_serializability(var_name, variable)
                status = "PASSED"
            except Exception as e:
                # Deliberately broad: any failure here is a finding to
                # report, not an error to propagate.
                status = "FAILED"
                msg = f"{e.__class__.__name__}: {e!s}"
                failure_set.add(var_name)
            # The original format string had a stray extra quote after the
            # variable name.
            printer(f"{variable!s}[name='{var_name}']... {status}")
            if msg:
                printer(msg)

    print(f"Trying to serialize {trainable}...")
    try:
        register_trainable("__test:" + str(trainable), trainable, warn=False)
        print("Serialization succeeded!")
        return True
    except Exception as e:
        # Fall through to the variable-by-variable inspection below.
        print(f"Serialization failed: {e}")

    print("Inspecting the scope of the trainable by running "
          f"`inspect.getclosurevars({trainable!s})`...")
    closure = inspect.getclosurevars(trainable)
    failure_set = set()
    if closure.globals:
        print(f"Detected {len(closure.globals)} global variables. "
              "Checking serializability...")
        check_variables(closure.globals, failure_set, print_indented)

    if closure.nonlocals:
        print(f"Detected {len(closure.nonlocals)} nonlocal variables. "
              "Checking serializability...")
        check_variables(closure.nonlocals, failure_set, print_indented)

    if not failure_set:
        print("Nothing was found to have failed the diagnostic test, though "
              "serialization did not succeed. Feel free to raise an "
              "issue on github.")
    else:
        print(f"Variable(s) {failure_set} was found to be non-serializable. "
              "Consider either removing the instantiation/imports "
              "of these objects or moving them into the scope of "
              "the trainable. ")
    # Both outcomes return the (possibly empty) set of failing names.
    return failure_set
Ejemplo n.º 27
0
def main(args):
    """Start Ray and launch a MADDPG experiment on a multi-agent particle env.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``scenario``, ``adv_policy``, ``good_policy``,
            ``num_adversaries``, ``num_episodes``, ``checkpoint_freq``,
            ``local_dir``, ``restore``, ``num_envs_per_worker``,
            ``max_episode_len``, ``num_units``, ``n_step``, ``gamma``,
            ``replay_buffer``, ``lr``, ``sample_batch_size``,
            ``train_batch_size``, ``num_workers`` and ``num_gpus``.
    """
    ray.init(redis_max_memory=int(1e10), object_store_memory=int(3e9))
    MADDPGAgent = maddpg.MADDPGTrainer.with_updates(mixins=[CustomStdOut])
    register_trainable("MADDPG", MADDPGAgent)
    # NOTE(review): the experiment below runs "contrib/MADDPG", not the
    # "MADDPG" name registered here -- confirm the customized trainer is
    # actually the one being used.

    def env_creater(mpe_args):
        return MultiAgentParticleEnv(**mpe_args)

    register_env("mpe", env_creater)

    # A local env instance is built only to read the per-agent
    # observation/action spaces needed for the policy specs.
    env = env_creater({
        "scenario_name": args.scenario,
    })

    def gen_policy(i):
        """Build the policy spec tuple for agent ``i``."""
        # Agents with an index below num_adversaries use the adversary
        # policy type, the rest use the good-agent type. The flag is
        # computed for this agent only instead of rebuilding the whole
        # per-agent list on every call.
        policy_type = (args.adv_policy
                       if i < args.num_adversaries else args.good_policy)
        return (None, env.observation_space_dict[i], env.action_space_dict[i],
                {
                    "agent_id": i,
                    "use_local_critic": policy_type == "ddpg",
                    "obs_space_dict": env.observation_space_dict,
                    "act_space_dict": env.action_space_dict,
                })

    policies = {
        "policy_%d" % i: gen_policy(i)
        for i in range(len(env.observation_space_dict))
    }
    # The mapping function indexes this list with the agent id it receives.
    policy_ids = list(policies.keys())

    run_experiments(
        {
            "MADDPG_RLLib": {
                "run": "contrib/MADDPG",
                "env": "mpe",
                "stop": {
                    "episodes_total": args.num_episodes,
                },
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                "restore": args.restore,
                "config": {
                    # === Log ===
                    "log_level": "ERROR",

                    # === Environment ===
                    "env_config": {
                        "scenario_name": args.scenario,
                    },
                    "num_envs_per_worker": args.num_envs_per_worker,
                    "horizon": args.max_episode_len,

                    # === Policy Config ===
                    # --- Model ---
                    "good_policy": args.good_policy,
                    "adv_policy": args.adv_policy,
                    "actor_hiddens": [args.num_units] * 2,
                    "actor_hidden_activation": "relu",
                    "critic_hiddens": [args.num_units] * 2,
                    "critic_hidden_activation": "relu",
                    "n_step": args.n_step,
                    "gamma": args.gamma,

                    # --- Exploration ---
                    "tau": 0.01,

                    # --- Replay buffer ---
                    "buffer_size": args.replay_buffer,

                    # --- Optimization ---
                    "actor_lr": args.lr,
                    "critic_lr": args.lr,
                    "learning_starts":
                    args.train_batch_size * args.max_episode_len,
                    "sample_batch_size": args.sample_batch_size,
                    "train_batch_size": args.train_batch_size,
                    "batch_mode": "truncate_episodes",

                    # --- Parallelism ---
                    "num_workers": args.num_workers,
                    "num_gpus": args.num_gpus,
                    "num_gpus_per_worker": 0,

                    # === Multi-agent setting ===
                    "multiagent": {
                        "policies":
                        policies,
                        "policy_mapping_fn":
                        ray.tune.function(lambda i: policy_ids[i])
                    },
                },
            },
        },
        verbose=0)