Example #1
def init_worker(actions_list=None):
    train_n_replicates = 1
    debug = True
    stop_iters = 200
    tf = False
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("testing")

    rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf)
    rllib_config["env"] = FakeEnvWtCstReward
    rllib_config["env_config"]["max_steps"] = EPI_LENGTH
    rllib_config["seed"] = int(time.time())
    # If a scripted action sequence is provided, replace each player's policy
    # class with a fake policy that replays those actions.
    if actions_list is not None:
        for policy_id in FakeEnvWtCstReward({}).players_ids:
            policy_to_modify = list(
                rllib_config["multiagent"]["policies"][policy_id]
            )
            policy_to_modify[0] = make_FakePolicyWtDefinedActions(
                copy.deepcopy(actions_list)
            )
            rllib_config["multiagent"]["policies"][
                policy_id
            ] = policy_to_modify

    pg_trainer = PGTrainer(
        rllib_config, logger_creator=_get_logger_creator(exp_name)
    )
    return pg_trainer.workers.local_worker()
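
A minimal usage sketch, not part of the original code: it assumes Ray has already been initialized and that the fake env and policy helpers referenced above are importable; RolloutWorker.sample() is the standard RLlib call for collecting a batch.

# Hypothetical usage of init_worker() with a scripted action sequence.
worker = init_worker(actions_list=[0, 1, 0, 1])
batch = worker.sample()  # MultiAgentBatch collected with the fake policies
print(batch.count)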
Example #2
def _get_hyperparameters(debug,
                      env=None,
                      train_n_replicates=None,
                      against_naive_opp=False):
    if debug:
        train_n_replicates = 1
    elif train_n_replicates is None:
        train_n_replicates = 4

    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("LTFT")

    hparameters = {
        "seeds": seeds,
        "debug": debug,
        "exp_name": exp_name,
        "hiddens": [64],
        "log_n_points": 260,
        "clustering_distance": 0.2,
        "gamma": 0.96,
        "env_name": "IteratedPrisonersDilemma" if env is None else env,
        # "env_name": "CoinGame" if env is None else env,
        "reward_uncertainty_std": 0.1,
        # "against_evader_exploiter": None,
        "against_evader_exploiter": {
            "start_exploit": 0.75,
            "copy_weights_delay": 0.05,
        } if not against_naive_opp else None,
    }

    hparameters = _modify_hyperparams_for_the_selected_env(hparameters)

    return hparameters, exp_name
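
A hedged usage sketch for the helper above; it assumes the env-specific modifier accepts "CoinGame" (as in the DQN CoinGame example below) and that 4 replicates are prepared when debug is False and no count is given.

# Hypothetical call of _get_hyperparameters().
hparameters, exp_name = _get_hyperparameters(debug=False, env="CoinGame")
print(hparameters["env_name"], len(hparameters["seeds"]))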
Example #3
def main(debug):
    train_n_replicates = 1 if debug else 1
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("DQN_CG_speed_search")

    env = "CoinGame"
    # env = "SSDMixedMotiveCoinGame"
    # welfare_to_use = None
    # welfare_to_use = postprocessing.WELFARE_UTILITARIAN
    welfare_to_use = postprocessing.WELFARE_INEQUITY_AVERSION

    if "SSDMixedMotiveCoinGame" in env:
        env_class = ssd_mixed_motive_coin_game.SSDMixedMotiveCoinGame
    else:
        env_class = coin_game.CoinGame

    hparams = _get_hyperparameters(seeds, debug, exp_name)

    rllib_config, stop_config = _get_rllib_configs(
        hparams, env_class=env_class
    )

    if welfare_to_use is not None:
        rllib_config = _modify_policy_to_use_welfare(
            rllib_config, welfare_to_use
        )

    rllib_config, stop_config = _add_search_to_config(
        rllib_config, stop_config, hparams
    )
    tune_analysis = _train_dqn_and_plot_logs(
        hparams, rllib_config, stop_config
    )

    return tune_analysis
Example #4
def generate_eval_config(hp):
    hp_eval = copy.deepcopy(hp)

    hp_eval["min_iter_time_s"] = 3.0
    hp_eval["seed"] = miscellaneous.get_random_seeds(1)[0]
    hp_eval["batch_size"] = 1
    hp_eval["num_episodes"] = 100

    tune_config, stop, env_config = get_tune_config(hp_eval)
    tune_config['TuneTrainerClass'] = LOLAExact

    hp_eval["group_names"] = ["lola"]
    hp_eval["scale_multipliers"] = (1 / tune_config['trace_length'],
                                    1 / tune_config['trace_length'])
    hp_eval["jitter"] = 0.05

    if hp_eval["env"] == "IPD":
        hp_eval["env"] = IteratedPrisonersDilemma
        hp_eval["x_limits"] = (-3.5, 0.5)
        hp_eval["y_limits"] = (-3.5, 0.5)
    elif hp_eval["env"] == "IMP":
        hp_eval["env"] = IteratedMatchingPennies
        hp_eval["x_limits"] = (-1.0, 1.0)
        hp_eval["y_limits"] = (-1.0, 1.0)
    elif hp_eval["env"] == "AsymBoS":
        hp_eval["env"] = IteratedAsymBoS
        hp_eval["x_limits"] = (0.0, 4.0)
        hp_eval["y_limits"] = (0.0, 4.0)
    else:
        raise NotImplementedError()

    rllib_config_eval = {
        "env": hp_eval["env"],
        "env_config": env_config,
        "multiagent": {
            "policies": {
                env_config["players_ids"][0]:
                (policy.get_tune_policy_class(PGTorchPolicy),
                 hp_eval["env"](env_config).OBSERVATION_SPACE,
                 hp_eval["env"].ACTION_SPACE, {
                     "tune_config": copy.deepcopy(tune_config)
                 }),
                env_config["players_ids"][1]:
                (policy.get_tune_policy_class(PGTorchPolicy),
                 hp_eval["env"](env_config).OBSERVATION_SPACE,
                 hp_eval["env"].ACTION_SPACE, {
                     "tune_config": copy.deepcopy(tune_config)
                 }),
            },
            "policy_mapping_fn": lambda agent_id: agent_id,
            "policies_to_train": ["None"],
        },
        "seed": hp_eval["seed"],
        "min_iter_time_s": hp_eval["min_iter_time_s"],
    }

    policies_to_load = copy.deepcopy(env_config["players_ids"])
    trainable_class = LOLAExact

    return (
        hp_eval,
        rllib_config_eval,
        policies_to_load,
        trainable_class,
        stop,
        env_config,
    )
Example #5
def main(debug, stop_iters=2000, tf=False):
    train_n_replicates = 1 if debug else 1
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("PPO_AsymCG")

    ray.init()

    stop = {
        "training_iteration": 2 if debug else stop_iters,
    }

    env_config = {
        "players_ids": ["player_red", "player_blue"],
        "max_steps": 20,
        "grid_size": 3,
        "get_additional_info": True,
    }

    rllib_config = {
        "env": AsymCoinGame,
        "env_config": env_config,
        "multiagent": {
            "policies": {
                env_config["players_ids"][0]:
                (None, AsymCoinGame(env_config).OBSERVATION_SPACE,
                 AsymCoinGame.ACTION_SPACE, {}),
                env_config["players_ids"][1]:
                (None, AsymCoinGame(env_config).OBSERVATION_SPACE,
                 AsymCoinGame.ACTION_SPACE, {}),
            },
            "policy_mapping_fn": lambda agent_id: agent_id,
        },
        # Size of batches collected from each worker.
        "rollout_fragment_length": 20,
        # Number of timesteps collected for each SGD round. This defines the size
        # of each SGD epoch.
        "train_batch_size": 512,
        "model": {
            "dim": env_config["grid_size"],
            "conv_filters": [[16, [3, 3], 1],
                             [32, [3, 3],
                              1]]  # [Channel, [Kernel, Kernel], Stride]]
        },
        "lr": 5e-3,
        "seed": tune.grid_search(seeds),
        "callbacks": log.get_logging_callbacks_class(),
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": "tf" if tf else "torch",
    }

    tune_analysis = tune.run(PPOTrainer,
                             config=rllib_config,
                             stop=stop,
                             checkpoint_freq=0,
                             checkpoint_at_end=True,
                             name=exp_name)
    ray.shutdown()
    return tune_analysis
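
A hedged follow-up sketch: since checkpoint_at_end=True, the returned ExperimentAnalysis can be queried for the best trial and its final checkpoint; get_best_trial and get_best_checkpoint are Ray Tune methods whose exact signatures vary across Ray versions.

# Hypothetical post-processing of the Tune results.
tune_analysis = main(debug=True)
best_trial = tune_analysis.get_best_trial("episode_reward_mean", mode="max")
print(tune_analysis.get_best_checkpoint(
    best_trial, metric="episode_reward_mean", mode="max"))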
Example #6
def _init_evaluator():
    rllib_config, stop_config = get_rllib_config(seeds=get_random_seeds(1))

    evaluator = self_and_cross_perf.SelfAndCrossPlayEvaluator(
        exp_name="testing_amTFT",
    )
    evaluator.define_the_experiment_to_run(evaluation_config=rllib_config,
                                           stop_config=stop_config,
                                           TrainerClass=PGTrainer)

    return evaluator
Example #7
def main(debug, stop_iters=200, tf=False):
    train_n_replicates = 1 if debug else 1
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("PG_IPD")

    ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)

    rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf)
    tune_analysis = tune.run(PGTrainer, config=rllib_config, stop=stop_config,
                             checkpoint_freq=0, checkpoint_at_end=True, name=exp_name)
    ray.shutdown()
    return tune_analysis
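
A hedged sketch of inspecting the returned analysis object, mirroring the results_df usage in the L1BR example below; it assumes episode_reward_mean is among the logged result columns.

# Hypothetical inspection of the training results.
tune_analysis = main(debug=True)
print(tune_analysis.results_df["episode_reward_mean"])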
Example #8
def main(debug, welfare=postprocessing.WELFARE_UTILITARIAN):
    train_n_replicates = 1 if debug else 1
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("DQN_welfare_CG")

    hparams = dqn_coin_game._get_hyperparameters(seeds, debug, exp_name)
    rllib_config, stop_config = dqn_coin_game._get_rllib_configs(hparams)
    rllib_config = _modify_policy_to_use_welfare(rllib_config, welfare)

    tune_analysis = dqn_coin_game._train_dqn_and_plot_logs(
        hparams, rllib_config, stop_config)

    return tune_analysis
Example #9
def main(debug):
    train_n_replicates = 1 if debug else 1
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("DQN_CG")

    hparams = _get_hyperparameters(seeds, debug, exp_name)

    rllib_config, stop_config = _get_rllib_configs(hparams)

    tune_analysis = _train_dqn_and_plot_logs(
        hparams, rllib_config, stop_config)

    return tune_analysis
Example #10
def main(debug):
    exp_name, _ = log.log_in_current_day_dir("L1BR_amTFT")

    train_n_replicates = 4 if debug else 8
    pool_of_seeds = miscellaneous.get_random_seeds(train_n_replicates)
    hparams = {
        "debug": debug,
        "filter_utilitarian": False,
        "train_n_replicates": train_n_replicates,
        "seeds": pool_of_seeds,
        "exp_name": exp_name,
        "n_steps_per_epi": 20,
        "bs_epi_mul": 4,
        "welfare_functions":
        [(postprocessing.WELFARE_UTILITARIAN, "utilitarian")],
        "amTFTPolicy": amTFT.amTFTRolloutsTorchPolicy,
        "explore_during_evaluation": True,
        "n_seeds_lvl0": train_n_replicates,
        "n_seeds_lvl1": train_n_replicates // 2,
        "gamma": 0.5,
        "lambda": 0.9,
        "alpha": 0.0,
        "beta": 1.0,
        "temperature_schedule": False,
        "debit_threshold": 4.0,
        "jitter": 0.05,
        "hiddens": [64],
        "env": matrix_sequential_social_dilemma.IteratedPrisonersDilemma,
        # "env": matrix_sequential_social_dilemma.IteratedAsymBoS,
        # "env": matrix_sequential_social_dilemma.IteratedAsymChicken,
        # "env": coin_game.CoinGame
        # "env": coin_game.AsymCoinGame

        # For training speed
        "min_iter_time_s": 0.0 if debug else 3.0,
        "overwrite_reward": True,
        "use_adam": False,
    }

    ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=hparams["debug"])

    hparams = amtft_various_env.modify_hyperparams_for_the_selected_env(
        hparams)
    lvl0_tune_analysis = train_lvl0_population(hp=hparams)
    tune_analysis_lvl1 = train_lvl1_agents(
        hp_lvl1=copy.deepcopy(hparams), tune_analysis_lvl0=lvl0_tune_analysis)
    print(tune_analysis_lvl1.results_df.columns)
    print(tune_analysis_lvl1.results_df.head())

    ray.shutdown()
Example #11
def _init_evaluator():
    exp_name, _ = log.log_in_current_day_dir("testing")

    rllib_config, stop_config = get_rllib_config(seeds=get_random_seeds(1))

    evaluator = self_and_cross_perf.SelfAndCrossPlayEvaluator(
        exp_name=exp_name,
    )
    evaluator.define_the_experiment_to_run(
        evaluation_config=rllib_config,
        stop_config=stop_config,
        TrainerClass=PGTrainer,
    )

    return evaluator
Example #12
def main(debug):
    train_n_replicates = 2 if debug else 40
    seeds = miscellaneous.get_random_seeds(train_n_replicates)

    exp_name, _ = log.log_in_current_day_dir("LOLA_Exact")

    hparams = {
        "load_plot_data": None,
        # Example "load_plot_data": ".../SameAndCrossPlay_save.p",
        "exp_name": exp_name,
        "train_n_replicates": train_n_replicates,
        "env": "IPD",
        # "env": "IMP",
        # "env": "AsymBoS",
        "num_episodes": 5 if debug else 50,
        "trace_length": 5 if debug else 200,
        "simple_net": True,
        "corrections": True,
        "pseudo": False,
        "num_hidden": 32,
        "reg": 0.0,
        "lr": 1.,
        "lr_correction": 1.0,
        "gamma": 0.96,
        "seed": tune.grid_search(seeds),
        "metric": "ret1",
        "with_linear_LR_decay_to_zero": False,
        "clip_update": None,

        # "with_linear_LR_decay_to_zero": True,
        # "clip_update": 0.1,
        # "lr": 0.001,
    }

    if hparams["load_plot_data"] is None:
        ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
        tune_analysis_per_exp = train(hparams)
    else:
        tune_analysis_per_exp = None

    evaluate(tune_analysis_per_exp, hparams)
    ray.shutdown()
Example #13
def _train_pg_in_ipd(train_n_replicates):
    debug = True
    stop_iters = 200
    tf = False
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("testing")

    ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)

    rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf)
    tune_analysis = tune.run(PGTrainer,
                             config=rllib_config,
                             stop=stop_config,
                             checkpoint_freq=0,
                             checkpoint_at_end=True,
                             name=exp_name,
                             metric="episode_reward_mean",
                             mode="max")
    ray.shutdown()
    return tune_analysis, seeds
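
A hedged test-style usage of the helper above: with the grid search over seeds, one Tune trial should run per seed.

# Hypothetical check that the number of trials matches the number of seeds.
tune_analysis, seeds = _train_pg_in_ipd(train_n_replicates=2)
assert len(tune_analysis.trials) == len(seeds)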
Example #14
def get_hyperparameters(debug, env):
    exp_name, _ = log.log_in_current_day_dir("L1BR_amTFT")

    train_n_replicates = 4 if debug else 8
    pool_of_seeds = miscellaneous.get_random_seeds(train_n_replicates)
    hparams = {
        "debug": debug,
        "filter_utilitarian": False,
        "train_n_replicates": train_n_replicates,
        "seeds": pool_of_seeds,
        "exp_name": exp_name,
        "welfare_functions":
        [(postprocessing.WELFARE_UTILITARIAN, "utilitarian")],
        "amTFTPolicy": amTFT.AmTFTRolloutsTorchPolicy,
        "explore_during_evaluation": True,
        "n_seeds_lvl0": train_n_replicates,
        "n_seeds_lvl1": train_n_replicates // 2,
        "gamma": 0.96,
        "temperature_schedule": False,
        "jitter": 0.05,
        "hiddens": [64],
        "env_name": "IteratedPrisonersDilemma",
        # "env_name": "IteratedAsymBoS",
        # "env_name": "IteratedAsymChicken",
        # "env_name": "CoinGame",
        # "env_name": "AsymCoinGame",
        "overwrite_reward": True,
        "reward_uncertainty": 0.0,
    }

    if env is not None:
        hparams["env_name"] = env

    hparams = amtft_various_env.modify_hyperparams_for_the_selected_env(
        hparams)

    return hparams, exp_name
Example #15
def main(debug):
    train_n_replicates = 1 if debug else 1
    seeds = miscellaneous.get_random_seeds(train_n_replicates)
    exp_name, _ = log.log_in_current_day_dir("LTFT_IPD")

    hparameters = {
        "n_epi": 10 if debug else 200,
        "n_steps_per_epi": 20,
        "bs_epi_mul": 4,
        "base_lr": 0.04,
        "spl_lr_mul": 10.0,
        "seeds": seeds,
        "debug": debug,
    }

    rllib_config, env_config, stop = get_rllib_config(hparameters)
    ray.init(num_cpus=os.cpu_count(), num_gpus=0)
    print("\n========== Training LTFT in self-play ==========\n")
    tune_analysis_self_play = ray.tune.run(
        DQNTrainer, config=rllib_config, verbose=1, checkpoint_freq=0,
        stop=stop, checkpoint_at_end=True, name=exp_name)

    print("\n========== Training LTFT against a naive opponent ==========\n")
    # Overwrite the second player's (player_col) policy spec with class None so
    # RLlib falls back to the trainer's default (naive selfish) policy.
    rllib_config["multiagent"]["policies"][env_config["players_ids"][1]] = (
        None,
        IteratedPrisonersDilemma.OBSERVATION_SPACE,
        IteratedPrisonersDilemma.ACTION_SPACE,
        {}
    )
    tune_analysis_naive_opponent = ray.tune.run(
        DQNTrainer, config=rllib_config, verbose=1, checkpoint_freq=0,
        stop=stop, checkpoint_at_end=True, name=exp_name)

    ray.shutdown()
    return tune_analysis_self_play, tune_analysis_naive_opponent
Example #16
def main(debug, train_n_replicates=None, filter_utilitarian=None):

    train_n_replicates = 1 if debug else train_n_replicates
    train_n_replicates = 40 if train_n_replicates is None else train_n_replicates
    n_times_more_utilitarians_seeds = 4
    pool_of_seeds = miscellaneous.get_random_seeds(
        train_n_replicates * (1 + n_times_more_utilitarians_seeds))
    exp_name, _ = log.log_in_current_day_dir("amTFT")
    hparams = {
        "debug": debug,
        "filter_utilitarian": (
            filter_utilitarian if filter_utilitarian is not None else not debug
        ),
        "train_n_replicates": train_n_replicates,
        "n_times_more_utilitarians_seeds": n_times_more_utilitarians_seeds,
        "load_plot_data": None,
        # Example: "load_plot_data": ".../SelfAndCrossPlay_save.p",
        "exp_name": exp_name,
        "n_steps_per_epi": 20,
        "bs_epi_mul": 4,
        "welfare_functions": [
            (postprocessing.WELFARE_INEQUITY_AVERSION, "inequity_aversion"),
            (postprocessing.WELFARE_UTILITARIAN, "utilitarian"),
        ],
        "seeds": pool_of_seeds,
        "amTFTPolicy": amTFT.amTFTRolloutsTorchPolicy,
        "explore_during_evaluation": True,
        "gamma": 0.5,
        "lambda": 0.9,
        "alpha": 0.0,
        "beta": 1.0,
        "temperature_schedule": False,
        "debit_threshold": 4.0,
        "jitter": 0.05,
        "hiddens": [64],

        # If not in self play, amTFT is evaluated against a naive selfish policy.
        "self_play": True,
        # "self_play": False,  # Not tested

        # "env": matrix_sequential_social_dilemma.IteratedPrisonersDilemma,
        # "utilitarian_filtering_threshold": -2.5,
        "env": matrix_sequential_social_dilemma.IteratedAsymBoS,
        "utilitarian_filtering_threshold": 3.2,
        # "env": matrix_sequential_social_dilemma.IteratedAsymChicken,
        # "utilitarian_filtering_threshold": ...,
        # "env": coin_game.CoinGame
        # "env": coin_game.AsymCoinGame
        # "utilitarian_filtering_threshold": ...,

        # For training speed
        "min_iter_time_s": 0.0 if debug else 3.0,
        "overwrite_reward": True,
        "use_adam": False,
    }
    hparams = modify_hyperparams_for_the_selected_env(hparams)

    if hparams["load_plot_data"] is None:
        ray.init(num_cpus=os.cpu_count(),
                 num_gpus=0,
                 local_mode=hparams["debug"])

        # Train
        tune_analysis_per_welfare = train_for_each_welfare_function(hparams)
        # Eval & Plot
        analysis_metrics_per_mode = evaluate_self_and_cross_perf(
            tune_analysis_per_welfare, hparams)

        ray.shutdown()
    else:
        tune_analysis_per_welfare = None
        # Plot
        analysis_metrics_per_mode = evaluate_self_and_cross_perf(
            tune_analysis_per_welfare, hparams)

    return tune_analysis_per_welfare, analysis_metrics_per_mode
Example #17
def get_hyperparameters(
    debug,
    train_n_replicates=None,
    filter_utilitarian=None,
    env=None,
    reward_uncertainty=0.0,
):
    if debug:
        train_n_replicates = 2
        n_times_more_utilitarians_seeds = 1
    elif train_n_replicates is None:
        n_times_more_utilitarians_seeds = 4
        train_n_replicates = 4
    else:
        n_times_more_utilitarians_seeds = 4

    n_seeds_to_prepare = train_n_replicates * (
        1 + n_times_more_utilitarians_seeds)
    pool_of_seeds = miscellaneous.get_random_seeds(n_seeds_to_prepare)
    exp_name, _ = log.log_in_current_day_dir("amTFT")
    hparams = {
        "debug": debug,
        "filter_utilitarian": (
            filter_utilitarian if filter_utilitarian is not None else not debug
        ),
        "seeds": pool_of_seeds,
        "train_n_replicates": train_n_replicates,
        "n_times_more_utilitarians_seeds": n_times_more_utilitarians_seeds,
        "exp_name": exp_name,
        "log_n_points": 250,
        "load_plot_data": None,
        # Example: "load_plot_data": ".../SelfAndCrossPlay_save.p",
        "load_policy_data": None,
        # "load_policy_data": {
        #     "Util": [
        #         ".../IBP/amTFT/trials/"
        #         "DQN_AsymCoinGame_...",
        #         ".../IBP/amTFT/trials/"
        #         "DQN_AsymCoinGame_..."],
        #     'IA':[
        #         ".../temp/IBP/amTFT/trials/"
        #         "DQN_AsymCoinGame_...",
        #         ".../IBP/amTFT/trials/"
        #         "DQN_AsymCoinGame_..."],
        # },
        # "load_policy_data": {
        #     "Util": [
        #         "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/amTFT"
        #         "/2021_03_28/19_38_55/utilitarian_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_06231_00000_0_seed=1616960338_2021-03-29_00-52-23/checkpoint_250/checkpoint-250",
        #         # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/amTFT"
        #         # "/2021_03_24/18_22_47/utilitarian_welfare/coop"
        #         # "/DQN_VectMixedMotiveCG_e1de7_00001_1_seed=1616610171_2021-03-25_00-27-29/checkpoint_250/checkpoint-250",
        #         # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/amTFT"
        #         # "/2021_03_24/18_22_47/utilitarian_welfare/coop"
        #         # "/DQN_VectMixedMotiveCG_e1de7_00002_2_seed=1616610172_2021-03-25_00-27-29/checkpoint_250/checkpoint-250",
        #         ],
        #     'IA':[
        #         "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible"
        #         "/amTFT/2021_03_28/19_38_55/inequity_aversion_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_d5a2a_00000_0_seed=1616960335_2021-03-28_21-23-26/checkpoint_250/checkpoint-250",
        #         # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible"
        #         # "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop"
        #         # "/DQN_VectMixedMotiveCG_9cfe6_00001_1_seed=1616610168_2021-03-24_20-22-11/checkpoint_250/checkpoint-250",
        #         # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible"
        #         # "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop"
        #         # "/DQN_VectMixedMotiveCG_9cfe6_00002_2_seed=1616610169_2021-03-24_20-22-11/checkpoint_250/checkpoint-250",
        #         ],
        # },
        # "load_policy_data": {
        #     "Util": [
        #         "~/ray_results/amTFT"
        #         "/2021_03_24/18_22_47/utilitarian_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_e1de7_00000_0_seed=1616610170_2021-03-25_00-27-29/checkpoint_250/checkpoint-250",
        #         "~/ray_results/amTFT"
        #         "/2021_03_24/18_22_47/utilitarian_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_e1de7_00001_1_seed=1616610171_2021-03-25_00-27-29/checkpoint_250/checkpoint-250",
        #         "~/ray_results/amTFT"
        #         "/2021_03_24/18_22_47/utilitarian_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_e1de7_00002_2_seed=1616610172_2021-03-25_00-27-29/checkpoint_250/checkpoint-250",
        #     ],
        #     'IA': [
        #         "~/ray_results"
        #         "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_9cfe6_00000_0_seed=1616610167_2021-03-24_20-22-10/checkpoint_250/checkpoint-250",
        #         "~/ray_results"
        #         "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_9cfe6_00001_1_seed=1616610168_2021-03-24_20-22-11/checkpoint_250/checkpoint-250",
        #         "~/ray_results"
        #         "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop"
        #         "/DQN_VectMixedMotiveCG_9cfe6_00002_2_seed=1616610169_2021-03-24_20-22-11/checkpoint_250/checkpoint-250",
        #     ],
        # },
        "amTFTPolicy":
        amTFT.AmTFTRolloutsTorchPolicy,
        "welfare_functions": [
            (postprocessing.WELFARE_INEQUITY_AVERSION, "inequity_aversion"),
            (postprocessing.WELFARE_UTILITARIAN, "utilitarian"),
        ],
        "jitter":
        0.05,
        "hiddens": [64],
        "gamma":
        0.96,
        # If not in self play then amTFT
        # will be evaluated against a naive selfish policy or an exploiter
        "self_play":
        True,
        # "self_play": False, # Not tested
        "env_name":
        "IteratedPrisonersDilemma" if env is None else env,
        # "env_name": "IteratedAsymBoS" if env is None else env,
        # "env_name": "CoinGame" if env is None else env,
        # "env_name": "AsymCoinGame" if env is None else env,
        # "env_name": "MixedMotiveCoinGame" if env is None else env,
        # "env_name": "SSDMixedMotiveCoinGame" if env is None else env,
        "overwrite_reward":
        True,
        "explore_during_evaluation":
        True,
        "reward_uncertainty":
        reward_uncertainty,
    }

    hparams = modify_hyperparams_for_the_selected_env(hparams)
    hparams["plot_keys"] = amTFT.PLOT_KEYS + hparams["plot_keys"]
    hparams["plot_assemblage_tags"] = (amTFT.PLOT_ASSEMBLAGE_TAGS +
                                       hparams["plot_assemblage_tags"])

    return hparams