Example #1
    def test_pettingzoo_env(self):
        register_env("simple_spread",
                     lambda _: PettingZooEnv(simple_spread_v2.env()))
        env = PettingZooEnv(simple_spread_v2.env())
        observation_space = env.observation_space
        action_space = env.action_space
        del env

        agent_class = get_algorithm_class("PPO")

        config = deepcopy(agent_class.get_default_config())

        config["multiagent"] = {
            # Set of policy IDs (by default, will use Trainer's
            # default policy class, the env's obs/act spaces and config={}).
            "policies": {
                "av": (None, observation_space, action_space, {})
            },
            # Mapping function that always returns "av" as policy ID to use
            # (for any agent).
            "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av",
        }

        config["log_level"] = "DEBUG"
        config["num_workers"] = 0
        config["rollout_fragment_length"] = 30
        config["train_batch_size"] = 200
        config["horizon"] = 200  # After n steps, force reset simulation
        config["no_done_at_end"] = False

        agent = agent_class(env="simple_spread", config=config)
        agent.train()
Example #2
def check_support(alg, config, test_eager=False, test_trace=True):
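    """Runs one training iteration of `alg` via tune.run() under tf eager
    (and/or eager tracing), for both continuous and discrete action envs."""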
    config["framework"] = "tfe"
    config["log_level"] = "ERROR"
    # Test both continuous and discrete actions.
    for cont in [True, False]:
        if cont and alg in ["DQN", "APEX", "SimpleQ"]:
            continue
        elif not cont and alg in ["DDPG", "APEX_DDPG", "TD3"]:
            continue

        if cont:
            config["env"] = "Pendulum-v1"
        else:
            config["env"] = "CartPole-v0"

        a = get_algorithm_class(alg)
        if test_eager:
            print("tf-eager: alg={} cont.act={}".format(alg, cont))
            config["eager_tracing"] = False
            tune.run(a,
                     config=config,
                     stop={"training_iteration": 1},
                     verbose=1)
        if test_trace:
            config["eager_tracing"] = True
            print("tf-eager-tracing: alg={} cont.act={}".format(alg, cont))
            tune.run(a,
                     config=config,
                     stop={"training_iteration": 1},
                     verbose=1)
Example #3
    def _do_test_fault_fatal(self, alg, config, eval_only=False):
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_algorithm_class(alg)

        # Test raises real error when out of workers.
        if not eval_only:
            config["num_workers"] = 2
            config["ignore_worker_failures"] = True
            # Make both worker idx=1 and 2 fail.
            config["env_config"] = {"bad_indices": [1, 2]}
        else:
            config["num_workers"] = 1
            config["evaluation_num_workers"] = 1
            config["evaluation_interval"] = 1
            config["evaluation_config"] = {
                "ignore_worker_failures": True,
                # Make eval worker (index 1) fail.
                "env_config": {
                    "bad_indices": [1],
                },
            }

        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            self.assertRaises(Exception, lambda: a.train())
            a.stop()
Example #4
    def _do_test_fault_ignore(self,
                              algo: str,
                              config: dict,
                              eval_only: bool = False):
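        """Checks that training continues despite a failing rollout worker when
        `ignore_worker_failures` (or its evaluation counterpart) is set."""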
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        algo_cls = get_algorithm_class(algo)

        # Test fault handling
        if not eval_only:
            config["num_workers"] = 2
            config["ignore_worker_failures"] = True
            # Make worker idx=1 fail. Other workers will be ok.
            config["env_config"] = {"bad_indices": [1]}
        else:
            config["num_workers"] = 1
            config["evaluation_num_workers"] = 2
            config["evaluation_interval"] = 1
            config["evaluation_config"] = {
                "ignore_worker_failures": True,
                "env_config": {
                    # Make worker idx=1 fail. Other workers will be ok.
                    "bad_indices": [1],
                },
            }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            algo = algo_cls(config=config, env="fault_env")
            result = algo.train()
            if not eval_only:
                self.assertTrue(result["num_healthy_workers"] == 1)
            else:
                self.assertTrue(result["num_healthy_workers"] == 1)
                self.assertTrue(
                    result["evaluation"]["num_healthy_workers"] == 1)
            algo.stop()
Example #5
    def _do_test_fault_fatal_but_recreate(self, alg, config):
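        """Checks that failed workers get recreated (and flagged as such) when
        `recreate_failed_workers=True`, and that training still succeeds."""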
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_algorithm_class(alg)

        # Test raises real error when out of workers
        config["num_workers"] = 2
        config["recreate_failed_workers"] = True
        # Make both worker idx=1 and 2 fail.
        config["env_config"] = {"bad_indices": [1, 2]}

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            a = agent_cls(config=config, env="fault_env")
            # Expect this to go well and all faulty workers are recovered.
            self.assertTrue(not any(
                ray.get(
                    worker.apply.remote(lambda w: w.recreated_worker or w.
                                        env_context.recreated_worker))
                for worker in a.workers.remote_workers()))
            result = a.train()
            self.assertEqual(result["num_healthy_workers"], 2)
            self.assertTrue(
                all(
                    ray.get(
                        worker.apply.remote(lambda w: w.recreated_worker and w.
                                            env_context.recreated_worker))
                    for worker in a.workers.remote_workers()))
            # This should also work several times.
            result = a.train()
            self.assertEqual(result["num_healthy_workers"], 2)
            a.stop()
Example #6
def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
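    """Trains `algo_name` on CartPole for `num_steps` iterations, then exports
    a policy checkpoint (for fine-tuning) and a SavedModel (for serving)."""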
    cls = get_algorithm_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    for _ in range(num_steps):
        alg.train()

    # Export tensorflow checkpoint for fine-tuning
    alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)
    # Export tensorflow SavedModel for online serving
    alg.export_policy_model(model_dir)
Example #7
    def _do_check(alg, config, a_name, o_name):
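        """Builds `alg` on a RandomEnv with the given action/observation spaces
        and asserts that the expected default model class is used."""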
        fw = config["framework"]
        action_space = ACTION_SPACES_TO_TEST[a_name]
        obs_space = OBSERVATION_SPACES_TO_TEST[o_name]
        print("=== Testing {} (fw={}) A={} S={} ===".format(
            alg, fw, action_space, obs_space))
        config.update(
            dict(env_config=dict(
                action_space=action_space,
                observation_space=obs_space,
                reward_space=Box(1.0, 1.0, shape=(), dtype=np.float32),
                p_done=1.0,
                check_action_bounds=check_bounds,
            )))
        stat = "ok"

        try:
            a = get_algorithm_class(alg)(config=config, env=RandomEnv)
        except ray.exceptions.RayActorError as e:
            if len(e.args) >= 2 and isinstance(e.args[2],
                                               UnsupportedSpaceException):
                stat = "unsupported"
            elif isinstance(e.args[0].args[2], UnsupportedSpaceException):
                stat = "unsupported"
            else:
                raise
        except UnsupportedSpaceException:
            stat = "unsupported"
        else:
            if alg not in ["DDPG", "ES", "ARS", "SAC"]:
                # 2D (image) input: Expect VisionNet.
                if o_name in ["atari", "image"]:
                    if fw == "torch":
                        assert isinstance(a.get_policy().model, TorchVisionNet)
                    else:
                        assert isinstance(a.get_policy().model, VisionNet)
                # 1D input: Expect FCNet.
                elif o_name == "vector1d":
                    if fw == "torch":
                        assert isinstance(a.get_policy().model, TorchFCNet)
                    else:
                        assert isinstance(a.get_policy().model, FCNet)
                # Could be either one: ComplexNet (if disabled Preprocessor)
                # or FCNet (w/ Preprocessor).
                elif o_name == "vector2d":
                    if fw == "torch":
                        assert isinstance(a.get_policy().model,
                                          (TorchComplexNet, TorchFCNet))
                    else:
                        assert isinstance(a.get_policy().model,
                                          (ComplexNet, FCNet))
            if train:
                a.train()
            a.stop()
        print(stat)
Example #8
def export_test(alg_name, failures, framework="tf"):
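    """Trains `alg_name` for one iteration, exports its policy model and
    checkpoint, and appends the algo to `failures` on invalid TF exports."""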
    def valid_tf_model(model_dir):
        return os.path.exists(os.path.join(model_dir, "saved_model.pb")) and os.listdir(
            os.path.join(model_dir, "variables")
        )

    def valid_tf_checkpoint(checkpoint_dir):
        return (
            os.path.exists(os.path.join(checkpoint_dir, "model.meta"))
            and os.path.exists(os.path.join(checkpoint_dir, "model.index"))
            and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))
        )

    cls = get_algorithm_class(alg_name)
    config = CONFIGS[alg_name].copy()
    config["framework"] = framework
    if "DDPG" in alg_name or "SAC" in alg_name:
        algo = cls(config=config, env="Pendulum-v1")
    else:
        algo = cls(config=config, env="CartPole-v0")

    for _ in range(1):
        res = algo.train()
        print("current status: " + str(res))

    export_dir = os.path.join(
        ray._private.utils.get_user_temp_dir(), "export_dir_%s" % alg_name
    )
    print("Exporting model ", alg_name, export_dir)
    algo.export_policy_model(export_dir)
    if framework == "tf" and not valid_tf_model(export_dir):
        failures.append(alg_name)
    shutil.rmtree(export_dir)

    if framework == "tf":
        print("Exporting checkpoint", alg_name, export_dir)
        algo.export_policy_checkpoint(export_dir)
        if framework == "tf" and not valid_tf_checkpoint(export_dir):
            failures.append(alg_name)
        shutil.rmtree(export_dir)

        print("Exporting default policy", alg_name, export_dir)
        algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL], export_dir)
        if not valid_tf_model(
            os.path.join(export_dir, ExportFormat.MODEL)
        ) or not valid_tf_checkpoint(os.path.join(export_dir, ExportFormat.CHECKPOINT)):
            failures.append(alg_name)

        # Test loading the exported model.
        model = tf.saved_model.load(os.path.join(export_dir, ExportFormat.MODEL))
        assert model

        shutil.rmtree(export_dir)
    algo.stop()
Example #9
    def test_pettingzoo_pistonball_v6_policies_are_dict_env(self):
        def env_creator(config):
            env = pistonball_v6.env()
            env = dtype_v0(env, dtype=float32)
            env = color_reduction_v0(env, mode="R")
            env = normalize_obs_v0(env)
            return env

        config = deepcopy(get_algorithm_class("PPO").get_default_config())
        config["env_config"] = {"local_ratio": 0.5}
        # Register env
        register_env("pistonball",
                     lambda config: PettingZooEnv(env_creator(config)))
        env = PettingZooEnv(env_creator(config))
        observation_space = env.observation_space
        action_space = env.action_space
        del env

        config["multiagent"] = {
            # Setup a single, shared policy for all agents.
            "policies": {
                "av": (None, observation_space, action_space, {})
            },
            # Map all agents to that policy.
            "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av",
        }

        config["log_level"] = "DEBUG"
        config["num_workers"] = 1
        # Fragment length, collected at once from each worker
        # and for each agent!
        config["rollout_fragment_length"] = 30
        # Training batch size -> Fragments are concatenated up to this point.
        config["train_batch_size"] = 200
        # After n steps, force reset simulation
        config["horizon"] = 200
        # Default: False
        config["no_done_at_end"] = False
        algo = get_algorithm_class("PPO")(env="pistonball", config=config)
        algo.train()
        algo.stop()
Example #10
    def _do_test_fault_fatal(self, alg, config):
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_algorithm_class(alg)

        # Test raises real error when out of workers
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        # Make both worker idx=1 and 2 fail.
        config["env_config"] = {"bad_indices": [1, 2]}

        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            self.assertRaises(Exception, lambda: a.train())
            a.stop()
Example #11
def check_support_multiagent(alg, config):
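    """Runs one training iteration of `alg` in a two-agent, two-policy setup
    (MountainCar for continuous-action algos, CartPole otherwise)."""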
    register_env("multi_agent_mountaincar",
                 lambda _: MultiAgentMountainCar({"num_agents": 2}))
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 2}))

    # Simulate a simple multi-agent setup.
    policies = {
        "policy_0": PolicySpec(config={"gamma": 0.99}),
        "policy_1": PolicySpec(config={"gamma": 0.95}),
    }
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        pol_id = policy_ids[agent_id]
        return pol_id

    config["multiagent"] = {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
    }

    for fw in framework_iterator(config):
        if fw in ["tf2", "tfe"
                  ] and alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]:
            continue
        if alg in ["DDPG", "APEX_DDPG", "SAC"]:
            a = get_algorithm_class(alg)(config=config,
                                         env="multi_agent_mountaincar")
        else:
            a = get_algorithm_class(alg)(config=config,
                                         env="multi_agent_cartpole")

        results = a.train()
        check_train_results(results)
        print(results)
        a.stop()
Example #12
    def _do_test_fault_ignore(self, algo: str, config: dict):
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        algo_cls = get_algorithm_class(algo)

        # Test fault handling
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        # Make worker idx=1 fail. Other workers will be ok.
        config["env_config"] = {"bad_indices": [1]}

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            algo = algo_cls(config=config, env="fault_env")
            result = algo.train()
            self.assertEqual(result["num_healthy_workers"], 1)
            algo.stop()
Example #13
def model_import_test(algo, config, env):
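    """Checks that `import_model()` loads h5 weights into the custom model and
    that subsequent training, saving, and restoring behave as expected."""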
    # Get the abs-path to use (bazel-friendly).
    rllib_dir = Path(__file__).parent.parent
    import_file = str(rllib_dir) + "/tests/data/model_weights/weights.h5"

    agent_cls = get_algorithm_class(algo)

    for fw in framework_iterator(config, ["tf", "torch"]):
        config["model"]["custom_model"] = ("keras_model"
                                           if fw != "torch" else "torch_model")

        agent = agent_cls(config, env)

        def current_weight(agent):
            if fw == "tf":
                return agent.get_weights(
                )[DEFAULT_POLICY_ID]["default_policy/value/kernel"][0]
            elif fw == "torch":
                return float(agent.get_weights()[DEFAULT_POLICY_ID]
                             ["value_branch.weight"][0][0])
            else:
                return agent.get_weights()[DEFAULT_POLICY_ID][4][0]

        # Import weights for our custom model from an h5 file.
        weight_before_import = current_weight(agent)
        agent.import_model(import_file=import_file)
        weight_after_import = current_weight(agent)
        check(weight_before_import, weight_after_import, false=True)

        # Train for a while.
        for _ in range(1):
            agent.train()
        weight_after_train = current_weight(agent)
        # Weights should have changed.
        check(weight_before_import, weight_after_train, false=True)
        check(weight_after_import, weight_after_train, false=True)

        # We can save the entire Agent and restore, weights should remain the
        # same.
        file = agent.save("after_train")
        check(weight_after_train, current_weight(agent))
        agent.restore(file)
        check(weight_after_train, current_weight(agent))

        # Import (untrained) weights again.
        agent.import_model(import_file=import_file)
        check(current_weight(agent), weight_after_import)
Example #14
def _register_all():
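    """Registers all RLlib and contrib algorithms (plus a few fake ones) as
    Tune trainables, and adds error stubs for bare contrib aliases."""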
    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    for key in (list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys()) +
                ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]):
        register_trainable(key, get_algorithm_class(key))

    def _see_contrib(name):
        """Returns dummy agent class warning algo is in contrib/."""
        class _SeeContrib(Algorithm):
            def setup(self, config):
                raise NameError(
                    "Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # Also register the aliases minus contrib/ to give a good error message.
    for key in list(CONTRIBUTED_ALGORITHMS.keys()):
        assert key.startswith("contrib/")
        alias = key.split("/", 1)[1]
        if alias not in ALGORITHMS:
            register_trainable(alias, _see_contrib(alias))
    print("Training policy until desired reward/timesteps/iterations. ...")
    results = tune.run(
        args.run,
        config=config,
        stop=stop,
        verbose=2,
        checkpoint_freq=1,
        checkpoint_at_end=True,
    )

    print("Training completed. Restoring new Trainer for action inference.")
    # Get the last checkpoint from the above training run.
    checkpoint = results.get_last_checkpoint()
    # Create new Trainer and restore its state from the last checkpoint.
    algo = get_algorithm_class(args.run)(config=config)
    algo.restore(checkpoint)

    # Create the env to do inference in.
    env = gym.make("FrozenLake-v1")
    obs = env.reset()

    num_episodes = 0
    episode_reward = 0.0

    while num_episodes < args.num_episodes_during_inference:
        # Compute an action (`a`).
        a = algo.compute_single_action(
            observation=obs,
            explore=args.explore_during_inference,
            policy_id="default_policy",  # <- default value
Example #16
    # TRAIN
    results = tune.run("RNNSAC", **config)

    # TEST
    checkpoint_config_path = str(Path(results.best_logdir) / "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)

    checkpoint_config["explore"] = False

    best_checkpoint = results.best_checkpoint
    print("Loading checkpoint: {}".format(best_checkpoint))

    algo = get_algorithm_class("RNNSAC")(
        env=config["config"]["env"], config=checkpoint_config
    )
    algo.restore(best_checkpoint)

    env = algo.env_creator({})
    state = algo.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()

    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_algo = algo.compute_single_action(
            obs,
            state=state,
Example #17
def run(args, parser):
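    """Restores a Trainer from a checkpoint (or its default config) and runs a
    rollout, optionally rendering and saving the collected trajectories."""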
    # Load configuration from checkpoint file.
    config_path = ""
    if args.checkpoint:
        config_dir = os.path.dirname(args.checkpoint)
        config_path = os.path.join(config_dir, "params.pkl")
        # Try parent directory.
        if not os.path.exists(config_path):
            config_path = os.path.join(config_dir, "../params.pkl")

    # Load the config from pickled.
    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            config = cloudpickle.load(f)
    # If no pkl file found, require command line `--config`.
    else:
        # If no config in given checkpoint -> Error.
        if args.checkpoint:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory AND no `--config` given on command "
                "line!")

        # Use default config for given agent.
        _, config = get_algorithm_class(args.run, return_config=True)

    # Make sure worker 0 has an Env.
    config["create_env_on_driver"] = True

    # Merge with `evaluation_config` (first try from command line, then from
    # pkl file).
    evaluation_config = copy.deepcopy(
        args.config.get("evaluation_config",
                        config.get("evaluation_config", {})))
    config = merge_dicts(config, evaluation_config)
    # Merge with command line `--config` settings (if not already the same
    # anyways).
    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    # Make sure we have evaluation workers.
    if not config.get("evaluation_num_workers"):
        config["evaluation_num_workers"] = config.get("num_workers", 0)
    if not config.get("evaluation_duration"):
        config["evaluation_duration"] = 1
    # Hard-override this, as the Trainer raises a warning otherwise.
    # Leaving it at None makes no sense anyway, since we don't call
    # `Trainer.train()` here.
    config["evaluation_interval"] = 1

    # Rendering and video recording settings.
    if args.no_render:
        deprecation_warning(old="--no-render", new="--render", error=False)
        args.render = False
    config["render_env"] = args.render

    ray.init(local_mode=args.local_mode)

    # Create the Trainer from config.
    cls = get_trainable_cls(args.run)
    agent = cls(env=args.env, config=config)

    # Load state from checkpoint, if provided.
    if args.checkpoint:
        agent.restore(args.checkpoint)

    num_steps = int(args.steps)
    num_episodes = int(args.episodes)

    # Do the actual rollout.
    with RolloutSaver(
            args.out,
            args.use_shelve,
            write_update_file=args.track_progress,
            target_steps=num_steps,
            target_episodes=num_episodes,
            save_info=args.save_info,
    ) as saver:
        rollout(agent, args.env, num_steps, num_episodes, saver,
                not args.render)
    agent.stop()
Example #18
                "rollout_fragment_length": 1000,
                "train_batch_size": 4000,
                "model": {"use_lstm": args.use_lstm},
            }
        )

    checkpoint_path = CHECKPOINT_FILE.format(args.run)
    # Attempt to restore from checkpoint, if possible.
    if not args.no_restore and os.path.exists(checkpoint_path):
        checkpoint_path = open(checkpoint_path).read()
    else:
        checkpoint_path = None

    # Manual training loop (no Ray tune).
    if args.no_tune:
        algo_cls = get_algorithm_class(args.run)
        algo = algo_cls(config=config)

        if checkpoint_path:
            print("Restoring from checkpoint path", checkpoint_path)
            algo.restore(checkpoint_path)

        # Serving and training loop.
        ts = 0
        for _ in range(args.stop_iters):
            results = algo.train()
            print(pretty_print(results))
            checkpoint = algo.save()
            print("Last checkpoint", checkpoint)
            # Write the location of the latest checkpoint to the checkpoint
            # file (not to `checkpoint_path`, which may be None here).
            with open(CHECKPOINT_FILE.format(args.run), "w") as f:
                f.write(checkpoint)
Example #19
def ckpt_restore_test(alg_name, tfe=False, object_store=False, replay_buffer=False):
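    """Trains one algorithm instance, restores its state into a second one, and
    checks that optimizer state, buffer contents, and computed actions match."""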
    config = CONFIGS[alg_name].copy()
    # If required, store replay buffer data in checkpoints as well.
    if replay_buffer:
        config["store_buffer_in_checkpoints"] = True

    frameworks = (["tf2"] if tfe else []) + ["torch", "tf"]
    for fw in framework_iterator(config, frameworks=frameworks):
        for use_object_store in [False, True] if object_store else [False]:
            print("use_object_store={}".format(use_object_store))
            cls = get_algorithm_class(alg_name)
            if "DDPG" in alg_name or "SAC" in alg_name:
                alg1 = cls(config=config, env="Pendulum-v1")
                alg2 = cls(config=config, env="Pendulum-v1")
            else:
                alg1 = cls(config=config, env="CartPole-v0")
                alg2 = cls(config=config, env="CartPole-v0")

            policy1 = alg1.get_policy()

            for _ in range(1):
                res = alg1.train()
                print("current status: " + str(res))

            # Check optimizer state as well.
            optim_state = policy1.get_state().get("_optimizer_variables")

            # Sync the models
            if use_object_store:
                alg2.restore_from_object(alg1.save_to_object())
            else:
                alg2.restore(alg1.save())

            # Compare optimizer state with re-loaded one.
            if optim_state:
                s2 = alg2.get_policy().get_state().get("_optimizer_variables")
                # Tf -> Compare states 1:1.
                if fw in ["tf2", "tf", "tfe"]:
                    check(s2, optim_state)
                # For torch, optimizers have state_dicts with keys=params,
                # which are different for the two models (ignore these
                # different keys, but compare all values nevertheless).
                else:
                    for i, s2_ in enumerate(s2):
                        check(
                            list(s2_["state"].values()),
                            list(optim_state[i]["state"].values()),
                        )

            # Compare buffer content with restored one.
            if replay_buffer:
                data = alg1.local_replay_buffer.replay_buffers[
                    "default_policy"
                ]._storage[42 : 42 + 42]
                new_data = alg2.local_replay_buffer.replay_buffers[
                    "default_policy"
                ]._storage[42 : 42 + 42]
                check(data, new_data)

            for _ in range(1):
                if "DDPG" in alg_name or "SAC" in alg_name:
                    obs = np.clip(
                        np.random.uniform(size=3),
                        policy1.observation_space.low,
                        policy1.observation_space.high,
                    )
                else:
                    obs = np.clip(
                        np.random.uniform(size=4),
                        policy1.observation_space.low,
                        policy1.observation_space.high,
                    )
                a1 = get_mean_action(alg1, obs)
                a2 = get_mean_action(alg2, obs)
                print("Checking computed actions", alg1, obs, a1, a2)
                if abs(a1 - a2) > 0.1:
                    raise AssertionError(
                        "algo={} [a1={} a2={}]".format(alg_name, a1, a2)
                    )
            # Stop both algos.
            alg1.stop()
            alg2.stop()
Example #20
def run_heuristic_vs_learned(args, use_lstm=False, algorithm="PG"):
    """Run heuristic policies vs a learned agent.

    The learned agent should eventually reach a reward of ~5 with
    use_lstm=False, and ~7 with use_lstm=True. The reason the LSTM policy
    can perform better is that it can distinguish between the always_same
    and beat_last heuristics.
    """
    def select_policy(agent_id, episode, **kwargs):
        if agent_id == "player_0":
            return "learned"
        else:
            return random.choice(["always_same", "beat_last"])

    config = {
        "env": "RockPaperScissors",
        "gamma": 0.9,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "num_envs_per_worker": 4,
        "rollout_fragment_length": 10,
        "train_batch_size": 200,
        "metrics_num_episodes_for_smoothing": 200,
        "multiagent": {
            "policies_to_train": ["learned"],
            "policies": {
                "always_same":
                PolicySpec(policy_class=AlwaysSameHeuristic),
                "beat_last":
                PolicySpec(policy_class=BeatLastHeuristic),
                "learned":
                PolicySpec(config={
                    "model": {
                        "use_lstm": use_lstm
                    },
                    "framework": args.framework,
                }),
            },
            "policy_mapping_fn": select_policy,
        },
        "framework": args.framework,
    }
    cls = get_algorithm_class(algorithm) if isinstance(algorithm,
                                                       str) else algorithm
    algo = cls(config=config)
    for _ in range(args.stop_iters):
        results = algo.train()
        # Episode stats not available yet -> keep training.
        if "policy_always_same_reward" not in results["hist_stats"]:
            reward_diff = 0
            continue
        reward_diff = sum(results["hist_stats"]["policy_learned_reward"])
        # Timesteps reached -> stop.
        if results["timesteps_total"] > args.stop_timesteps:
            break
        # Reward (difference) reached -> all good, return.
        elif reward_diff > args.stop_reward:
            return

    # Reward (difference) not reached: Error if `as_test`.
    if args.as_test:
        raise ValueError(
            "Desired reward difference ({}) not reached! Only got to {}.".
            format(args.stop_reward, reward_diff))
            experiment["config"]["eager_tracing"] = True
        # experiment["config"]["callbacks"] = MemoryTrackingCallbacks

        # Move "env" specifier into config.
        experiment["config"]["env"] = experiment["env"]
        experiment.pop("env", None)

        # Print out the actual config.
        print("== Test config ==")
        print(yaml.dump(experiment))

        # Construct the trainer instance based on the given config.
        leaking = True
        try:
            ray.init(num_cpus=5, local_mode=args.local_mode)
            trainer = get_algorithm_class(experiment["run"])(
                experiment["config"])
            results = check_memory_leaks(
                trainer,
                to_check=set(args.to_check),
            )
            if not results:
                leaking = False
        finally:
            ray.shutdown()

        if not leaking:
            print("Memory leak test PASSED")
        else:
            print("Memory leak test FAILED. Exiting with Error.")
            sys.exit(1)