Example 1
def validate_config(config):
    check_config_and_setup_param_noise(config=config)

    if config["callbacks"]["on_episode_end"]:
        episode_end_callback = config["callbacks"]["on_episode_end"]
    else:
        episode_end_callback = None

    def on_episode_end(params):
        # Add per-policy win-rate metrics
        if episode_end_callback is not None:
            episode_end_callback(params)
        episode = params['episode']

        agent_last_infos = {
            policy_key: episode.last_info_for(agent_id)
            for agent_id, policy_key in episode.agent_rewards.keys()
        }
        for policy_key, info in agent_last_infos.items():
            if info is not None:
                if info['game_result'] == 'won':
                    episode.custom_metrics[policy_key + '_win'] = 1.0
                else:
                    episode.custom_metrics[policy_key + '_win'] = 0.0

    config["callbacks"]["on_episode_end"] = on_episode_end

    for eval_name, eval_config in config['extra_evaluation_configs'].items():
        if 'env_config' not in eval_config or not eval_config['env_config']:
            eval_config['env_config'] = {}
        eval_config['env_config'] = with_base_config(
            base_config=config["default_extra_eval_env_config"],
            extra_config=eval_config['env_config'])
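
The validate_config example above leans on with_base_config to layer an evaluation-specific env_config on top of a default one. In the older RLlib releases these snippets target, that helper behaves essentially as a deep copy of the base dict followed by a shallow update with the extra dict; the sketch below mirrors that assumed behavior and is not the library's exact source.

import copy

def with_base_config_sketch(base_config, extra_config):
    # Assumed behavior of with_base_config in the RLlib versions these
    # snippets target: deep-copy the base dict, then overlay the extra keys.
    config = copy.deepcopy(base_config)
    config.update(extra_config)
    return config

# Keys present in extra_config win; everything else falls through from the base.
merged = with_base_config_sketch({"gamma": 0.99, "lr": 5e-4}, {"lr": 1e-4})
assert merged == {"gamma": 0.99, "lr": 1e-4}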
Example 2
def ppo_custom_eval_trainer_validate_config(config):

    if config['sgd_minibatch_size'] == -1:
        config['sgd_minibatch_size'] = config['train_batch_size']

    validate_config(config)

    if config["redo_invalid_games"]:
        assert config['num_envs_per_worker'] == 1, \
            "redo_invalid_games requires num_envs_per_worker to be set to 1"
        assert config['sample_batch_size'] == 1, \
            "redo_invalid_games requires sample_batch_size to be set to 1"
        assert config['batch_mode'] == 'complete_episodes', \
            "redo_invalid_games requires batch_mode to be set to 'complete_episodes'"

        if config["callbacks"]["on_sample_end"]:
            sample_end_callback = config["callbacks"]["on_sample_end"]
        else:
            sample_end_callback = None

        config["callbacks"]["on_sample_end"] = tune.function(
            get_redo_sample_if_game_result_was_invalid_worker_callback(
                sample_end_callback))

    if config["callbacks"]["on_episode_end"]:
        episode_end_callback = config["callbacks"]["on_episode_end"]
    else:
        episode_end_callback = None

    def on_episode_end(params):
        # Add per-policy win-rate metrics
        if episode_end_callback is not None:
            episode_end_callback(params)
        episode = params['episode']

        agent_last_infos = {
            policy_key: episode.last_info_for(agent_id)
            for agent_id, policy_key in episode.agent_rewards.keys()
        }
        for policy_key, info in agent_last_infos.items():
            if info is not None:
                if info['game_result'] == 'won':
                    episode.custom_metrics[policy_key + '_win'] = 1.0
                else:
                    episode.custom_metrics[policy_key + '_win'] = 0.0

    config["callbacks"]["on_episode_end"] = on_episode_end

    for eval_name, eval_config in config['extra_evaluation_configs'].items():
        if 'env_config' not in eval_config or not eval_config['env_config']:
            eval_config['env_config'] = {}
        eval_config['env_config'] = with_base_config(
            base_config=config["default_extra_eval_env_config"],
            extra_config=eval_config['env_config'])
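
Both examples wrap whatever on_episode_end callback was already configured, invoke it first, and only then append their own win-rate metrics. That pattern can be factored into a small helper; chain_callbacks below is a hypothetical name, not an RLlib API.

def chain_callbacks(*callbacks):
    # Hypothetical helper mirroring the pattern above: call each non-None
    # callback in order with the same params dict.
    def combined(params):
        for callback in callbacks:
            if callback is not None:
                callback(params)
    return combined

# e.g. config["callbacks"]["on_episode_end"] = chain_callbacks(
#          config["callbacks"].get("on_episode_end"), add_win_rate_metrics)
# where add_win_rate_metrics would hold the custom-metrics logic shown above.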
Example 3
File: ddppo.py Project: zqxyz73/ray
Note that unlike the paper, we currently do not implement straggler mitigation.
"""

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, {
    # During the sampling phase, each rollout worker will collect a batch of
    # `sample_batch_size * num_envs_per_worker` steps.
    "sample_batch_size": 100,
    # Vectorize the env (worth enabling by default, since each worker has a GPU).
    "num_envs_per_worker": 5,
    # During the SGD phase, workers iterate over minibatches of this size.
    # The effective minibatch size will be `sgd_minibatch_size * num_workers`.
    "sgd_minibatch_size": 50,
    # Number of SGD epochs per optimization round.
    "num_sgd_iter": 10,

    # *** WARNING: configs below are DDPPO overrides over PPO; you
    #     shouldn't need to adjust them. ***
    "use_pytorch": True,  # DDPPO requires PyTorch distributed.
    "num_gpus": 0,  # Learning is no longer done on the driver process, so
                    # giving GPUs to the driver does not make sense!
    "num_gpus_per_worker": 1,  # Each rollout worker gets a GPU.
    "truncate_episodes": True,  # Require evenly sized batches. Otherwise,
                                # collective allreduce could fail.
    "train_batch_size": -1,  # This is auto set based on sample batch size.
})
# __sphinx_doc_end__
# yapf: enable
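# A minimal sketch of the batch sizes these defaults imply, based on the
# comments above (our reading, not DD-PPO's exact bookkeeping):
#   rollout per worker  = sample_batch_size * num_envs_per_worker = 100 * 5 = 500 steps
#   effective SGD batch = sgd_minibatch_size * num_workers, with num_workers
#                         inherited from ppo.DEFAULT_CONFIG
# train_batch_size stays at -1 here and is derived from the rollout size later.
assert DEFAULT_CONFIG["sample_batch_size"] * DEFAULT_CONFIG["num_envs_per_worker"] == 500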


def validate_config(config):
Example 4
File: appo.py Project: locussam/ray
DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
    # Whether to use V-trace weighted advantages. If false, PPO GAE advantages
    # will be used instead.
    "vtrace": False,

    # == These two options only apply if vtrace: False ==
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # GAE(lambda) parameter
    "lambda": 1.0,

    # == PPO surrogate loss options ==
    "clip_param": 0.4,

    # == IMPALA optimizer params (see documentation in impala.py) ==
    "sample_batch_size": 50,
    "train_batch_size": 500,
    "min_iter_time_s": 10,
    "num_workers": 2,
    "num_gpus": 1,
    "num_data_loader_buffers": 1,
    "minibatch_buffer_size": 1,
    "num_sgd_iter": 1,
    "replay_proportion": 0.0,
    "replay_buffer_num_slots": 100,
    "learner_queue_size": 16,
    "max_sample_requests_in_flight_per_worker": 2,
    "broadcast_interval": 1,
    "grad_clip": 40.0,
    "opt_type": "adam",
    "lr": 0.0005,
    "lr_schedule": None,
    "decay": 0.99,
    "momentum": 0.0,
    "epsilon": 0.1,
    "vf_loss_coeff": 0.5,
    "entropy_coeff": 0.01,
})
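
These APPO defaults are typically consumed by overlaying experiment-specific values with another with_base_config call. The keys below come from the dict above; the override values are only illustrative, not recommendations.

# Illustrative override of a few APPO defaults defined above; the values are
# arbitrary examples.
my_appo_config = with_base_config(DEFAULT_CONFIG, {
    "vtrace": True,      # switch to V-trace weighted advantages
    "num_workers": 8,
    "lr": 1e-4,
})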
Example 5
    "export_policy_weights_ids": [],
    "redo_invalid_games": False,
    "wandb": {},
    "ed": None,
    "policy_catalog": None,
    "eq_iters": None,
    "adaptive_pval_test": False,
    "br_thres": None,
    "eq_thres": None,
    "br_eval_against_policy": None,
    "thres_is_pval": None,
    "adaptive_pval": None
}

PG_CUSTOM_EVAL_TRAINER_DEFAULT_CONFIG = with_base_config(
    base_config=DEFAULT_CONFIG,
    extra_config=ppo_custom_eval_trainer_added_config_items)

mixins = [CustomEvaluationsTrainerMixin, WeightsUtilsTrainerMixin]

SACTrainer = GenericOffPolicyTrainer.with_updates(
    name="SACDiscrete",
    default_config=PG_CUSTOM_EVAL_TRAINER_DEFAULT_CONFIG,
    default_policy=SACTFPolicy,
    validate_config=validate_config,
    before_init=pg_custom_eval_trainer_before_init,
    after_init=pg_custom_eval_trainer_after_init,
    after_optimizer_step=after_optimizer_step,
    collect_metrics_fn=collect_metrics,
    make_policy_optimizer=make_optimizer,
    mixins=mixins)
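
A trainer produced by with_updates is used like any other trainer in this older RLlib API. A minimal, hypothetical usage sketch follows; the env id is a placeholder and assumes the environment has been registered with RLlib.

# Hypothetical usage sketch; "my_registered_env" is a placeholder env id.
trainer = SACTrainer(config=PG_CUSTOM_EVAL_TRAINER_DEFAULT_CONFIG,
                     env="my_registered_env")
result = trainer.train()  # one training iteration
print(result["episode_reward_mean"])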
Example 6
    @override(VTracePolicyGraph)
    def _get_loss_inputs_dict(self, batch):
        feed = {}

        def add_feed(ph, val):
            feed[ph] = val

        util.deepZipWith(add_feed, self._loss_input_dict, batch)
        return feed


DEFAULT_CONFIG = trainer.with_base_config(
    impala.DEFAULT_CONFIG, {
        "data_path": None,
        "autoregressive": False,
        "residual": False,
        "imitation": True,
    })


class DummyTrainer(impala.ImpalaTrainer):
    """Imitation learning."""

    _name = "IMITATION"
    _default_config = DEFAULT_CONFIG
    _policy_graph = ImitationPolicyGraph

    def __setstate__(self, state):
        self.local_evaluator.policy_map[DEFAULT_POLICY_ID].set_state(state)
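
util.deepZipWith is project-specific; the sketch below captures the behavior _get_loss_inputs_dict appears to rely on, namely walking two parallel nested dicts and applying a function to each matching pair of leaves. This is our reading of the call site, not the actual implementation.

def deep_zip_with_sketch(fn, left, right):
    # Assumed semantics: recurse through two parallel nested dicts and call
    # fn(left_leaf, right_leaf) at matching leaves, as the feed-dict
    # construction above expects (placeholders on the left, batch values on
    # the right).
    if isinstance(left, dict):
        for key in left:
            deep_zip_with_sketch(fn, left[key], right[key])
    else:
        fn(left, right)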
Example 7
def measure_exploitability_of_metanashes_as_they_become_available():
    logger = get_logger()

    storage_client = connect_storage_client()

    worker_id = f"Exploitability_Tracker_{gethostname()}_pid_{os.getpid()}_{datetime_str()}"

    manager_interface = ConsoleManagerInterface(
        server_host=MANAGER_SEVER_HOST,
        port=MANAGER_PORT,
        worker_id=worker_id,
        storage_client=storage_client,
        minio_bucket_name=BUCKET_NAME,
        minio_local_dir=DEFAULT_LOCAL_SAVE_PATH)

    logger.info(f"Started worker \'{worker_id}\'")

    # If you use ray for more than just this single example fn, you'll need to move ray.init to the top of your main()
    ray.init(address=os.getenv('RAY_HEAD_NODE'),
             ignore_reinit_error=True,
             local_mode=True)

    model_config_file_path, _ = maybe_download_object(
        storage_client=storage_client,
        bucket_name=BUCKET_NAME,
        object_name=MODEL_CONFIG_KEY,
        force_download=False)

    with open(model_config_file_path, 'r') as config_file:
        model_config = json.load(fp=config_file)

    example_env = PokerMultiAgentEnv(env_config=POKER_ENV_CONFIG)

    logger.info("\n\n\n\n\n__________________________________________\n"
                f"LAUNCHED FOR {POKER_GAME_VERSION}\n"
                f"__________________________________________\n\n\n\n\n")

    obs_space = example_env.observation_space
    act_space = example_env.action_space

    preprocessor = StrategoDictFlatteningPreprocessor(obs_space=obs_space)
    graph = tf.Graph()
    sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}),
                      graph=graph)

    def fetch_logits(policy):
        return {
            "behaviour_logits": policy.model.last_output(),
        }

    _policy_cls = POLICY_CLASS.with_updates(
        extra_action_fetches_fn=fetch_logits)

    with graph.as_default():
        with sess.as_default():
            policy = _policy_cls(obs_space=preprocessor.observation_space,
                                 action_space=act_space,
                                 config=with_common_config({
                                     'model':
                                     with_base_config(
                                         base_config=MODEL_DEFAULTS,
                                         extra_config=model_config),
                                     'env':
                                     POKER_ENV,
                                     'env_config':
                                     POKER_ENV_CONFIG,
                                     'custom_preprocessor':
                                     STRATEGO_PREPROCESSOR
                                 }))

    def set_policy_weights(weights_key):
        weights_file_path, _ = maybe_download_object(
            storage_client=storage_client,
            bucket_name=BUCKET_NAME,
            object_name=weights_key,
            force_download=False)
        policy.load_model_weights(weights_file_path)

    print("(Started Successfully)")

    last_payoff_table_key = None
    while True:
        payoff_table, payoff_table_key = manager_interface.get_latest_payoff_table(
            infinite_retry_on_error=True)
        if payoff_table_key == last_payoff_table_key:
            time.sleep(20)
            continue
        last_payoff_table_key = payoff_table_key

        metanash_probs, _, _ = get_fp_metanash_for_latest_payoff_table(
            manager_interface=manager_interface,
            fp_iters=20000,
            accepted_opponent_policy_class_names=[POLICY_CLASS_NAME],
            accepted_opponent_model_config_keys=[POKER_ENV_CONFIG],
            add_payoff_matrix_noise_std_dev=0.000,
            mix_with_uniform_dist_coeff=None,
            p_or_lower_rounds_to_zero=0.0)

        if metanash_probs is not None:
            policy_weights_keys = (
                payoff_table.get_ordered_keys_in_payoff_matrix())

            policy_dict = {
                key: prob
                for key, prob in zip(policy_weights_keys, metanash_probs)
            }

            exploitability = measure_exploitability_nonlstm(
                rllib_policy=policy,
                poker_game_version=POKER_GAME_VERSION,
                policy_mixture_dict=policy_dict,
                set_policy_weights_fn=set_policy_weights)
            print(f"Exploitability: {exploitabilitly}")
Example 8
    def get_policy_fn(stratego_env_config):

        from mprl.utility_services.cloud_storage import maybe_download_object
        from mprl.rl.sac.sac_policy import SACDiscreteTFPolicy
        from mprl.rl.ppo.ppo_stratego_model_policy import PPOStrategoModelTFPolicy
        from mprl.rl.common.stratego_preprocessor import STRATEGO_PREPROCESSOR, StrategoDictFlatteningPreprocessor
        from ray.rllib.agents.trainer import with_common_config, with_base_config
        from ray.rllib.models.catalog import MODEL_DEFAULTS
        from mprl.rl.common.sac_spatial_stratego_model import SAC_SPATIAL_STRATEGO_MODEL
        import ray
        from ray.rllib.utils import try_import_tf
        import json
        import os
        tf = try_import_tf()

        from tensorflow.python.client import device_lib

        def get_available_gpus():
            local_device_protos = device_lib.list_local_devices()
            return [x.name for x in local_device_protos if x.device_type == 'GPU']


        # If you use ray for more than just this single example fn, you'll need to move ray.init to the top of your main()
        ray.init(address=os.getenv('RAY_HEAD_NODE'), ignore_reinit_error=True, local_mode=True)

        if policy_class_name == 'PPOStrategoModelTFPolicy':
            _policy_class = PPOStrategoModelTFPolicy
        elif policy_class_name == 'SACDiscreteTFPolicy':
            _policy_class = SACDiscreteTFPolicy
        else:
            raise NotImplementedError(f"Eval for policy class '{policy_class_name}' not implemented.")

        if model_config_object_key:
            with download_lock:
                model_config_file_path, _ = maybe_download_object(storage_client=storage_client,
                                                                  bucket_name=minio_bucket_name,
                                                                  object_name=model_config_object_key,
                                                                  force_download=False)

                with open(model_config_file_path, 'r') as config_file:
                    model_config = json.load(fp=config_file)
        else:
            model_config = manual_config

        example_env = stratego_env_config['env_class'](env_config=stratego_env_config)
        obs_space = example_env.observation_space
        act_space = example_env.action_space

        preprocessor = StrategoDictFlatteningPreprocessor(obs_space=obs_space)


        graph = tf.Graph()

        if os.getenv("EVALUATOR_USE_GPU") == 'true':
            gpu = 1
        else:
            gpu = 0

        config = tf.ConfigProto(device_count={'GPU': gpu})
        if gpu:
            config.gpu_options.allow_growth = True
        sess = tf.Session(config=config, graph=graph)

        with graph.as_default():
            with sess.as_default():
                policy = _policy_class(
                    obs_space=preprocessor.observation_space,
                    action_space=act_space,
                    config=with_common_config({
                        'model': with_base_config(base_config=MODEL_DEFAULTS, extra_config=model_config),
                        'env': POKER_ENV,
                        'env_config': stratego_env_config,
                        'custom_preprocessor': STRATEGO_PREPROCESSOR,
                    }))

                if model_weights_object_key:
                    with download_lock:
                        weights_file_path, _ = maybe_download_object(storage_client=storage_client,
                                                                     bucket_name=minio_bucket_name,
                                                                     object_name=model_weights_object_key,
                                                                     force_download=False)
                        policy.load_model_weights(weights_file_path)
                    policy.current_model_weights_key = weights_file_path
                else:
                    policy.current_model_weights_key = None

        def policy_fn(observation, policy_state=None):
            if policy_state is None:
                policy_state = policy.get_initial_state()

            current_player_perspective_action_index, policy_state, _ = policy.compute_single_action(
                obs=preprocessor.transform(observation),
                state=policy_state)

            return current_player_perspective_action_index, policy_state

        if population_policy_keys_to_selection_probs is not None:

            def sample_new_policy_weights_from_population():
                new_policy_key = np.random.choice(a=list(population_policy_keys_to_selection_probs.keys()),
                                                  p=list(population_policy_keys_to_selection_probs.values()))
                if new_policy_key != policy.current_model_weights_key:
                    with download_lock:
                        weights_file_path, _ = maybe_download_object(storage_client=storage_client,
                                                                     bucket_name=minio_bucket_name,
                                                                     object_name=new_policy_key,
                                                                     force_download=False)
                        policy.load_model_weights(weights_file_path)
                        logger.debug(f"Sampling new population weights from {new_policy_key}")
                    policy.current_model_weights_key = new_policy_key

            return policy_name, policy_fn, sample_new_policy_weights_from_population

        # policy name must be unique
        return policy_name, policy_fn
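
The returned policy_fn threads recurrent state between calls. A hypothetical driver loop, assuming the two-value return path (no population configured) and a Gym-style reset()/step() interface on the env, might look like this:

# Hypothetical driver loop; the Gym-style env interface and single-agent
# observation handling are assumptions, not part of the snippet above.
env = stratego_env_config['env_class'](env_config=stratego_env_config)
policy_name, policy_fn = get_policy_fn(stratego_env_config)

obs = env.reset()
policy_state, done = None, False
while not done:
    action, policy_state = policy_fn(obs, policy_state=policy_state)
    obs, reward, done, info = env.step(action)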
Example 9
def with_rllib_info(info: Info) -> Info:
    """Merge info with RLlib's common parameters' info."""
    info = with_base_config(COMMON_INFO, info)
    info = tree.map_structure(
        (lambda x: inspect.cleandoc(x) if isinstance(x, str) else x), info)
    return info
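
The cleanup step in with_rllib_info applies tree.map_structure with inspect.cleandoc to every string leaf. A small self-contained illustration of that part follows; the dict contents are made up.

import inspect

import tree  # dm-tree

raw_info = {
    "lr": """    Learning rate.
        Overrides the base value.""",
    "nested": {"gamma": "  Discount factor."},
}
cleaned = tree.map_structure(
    lambda x: inspect.cleandoc(x) if isinstance(x, str) else x, raw_info)
# cleaned["lr"] == "Learning rate.\nOverrides the base value."
# cleaned["nested"]["gamma"] == "Discount factor."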