Example #1
0
def main():
    """Rate all parties (with their action samplers) found in the given files.

    Command line:
        dst            path of the pickle file to write
        party_file     one or more party pickle files
        --match_count  forwarded to ``rating_battle`` (default 100)

    The output pickle holds the keys ``parties``, ``party_metadatas``,
    ``rates`` and ``log``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("dst")
    cli.add_argument("party_file", nargs="+")
    cli.add_argument("--match_count", type=int, default=100,
                     help="1パーティあたりの対戦回数")
    opts = cli.parse_args()
    context.init()

    # The environment is built from the first party of the first file; the
    # same party is passed for both positional arguments.
    # NOTE(review): presumably (own party, enemy pool) - confirm against PokeEnv.
    first_file_entries = load_pickle(opts.party_file[0])["parties"]
    placeholder_party = first_file_entries[0]["party"]
    feature_types = "enemy_type hp_ratio nv_condition rank".split(" ")
    env = PokeEnv(placeholder_party, [placeholder_party],
                  feature_types=feature_types)

    all_parties, all_samplers, all_metadata = [], [], []
    print("loading parties")
    for path in opts.party_file:
        loaded = load_parties_agents(env, path)
        file_parties, file_samplers, file_metadata = loaded
        all_parties += file_parties
        all_samplers += file_samplers
        all_metadata += file_metadata

    print("rating")
    rates, log = rating_battle(env, all_parties, all_samplers,
                               opts.match_count)
    payload = {
        "parties": all_parties,
        "party_metadatas": all_metadata,
        "rates": rates,
        "log": log,
    }
    save_pickle(payload, opts.dst)
Example #2
0
def load_parties_agents(env: PokeEnv, party_file_path: str):
    """
    Load the parties of a party file together with an action sampler for each.

    For every party, trained-agent directories are searched with the pattern
    ``<directory of party_file_path>/<party uuid>/*_finish``. If at least one
    matches, the party is emitted with the agent loaded from the first match
    (in sorted order) and tagged ``"policy": "rl"``. Regardless of that, every
    party is also emitted once with a random action sampler, tagged
    ``"policy": "random"``.

    :param env: environment handed to ``load_agent``
    :param party_file_path: pickle file whose ``"parties"`` entry is a list of
        dicts with the keys ``"party"`` and ``"uuid"``
    :return: tuple ``(parties, action_samplers, metadata)`` of three lists of
        equal length, aligned by index
    """
    base_dir = os.path.dirname(party_file_path)
    parties = []
    action_samplers = []
    metadata = []
    for party_with_meta in load_pickle(party_file_path)["parties"]:
        party = party_with_meta["party"]
        party_uuid = party_with_meta["uuid"]
        # glob.glob returns matches in arbitrary order; sort them so the agent
        # that gets picked is reproducible when several "*_finish" dirs exist.
        agent_dirs = sorted(
            glob.glob(os.path.join(base_dir, party_uuid, "*_finish")))
        if agent_dirs:
            action_sampler = load_agent(env, agent_dirs[0])
            parties.append(party)
            action_samplers.append(action_sampler)
            metadata.append({"uuid": party_uuid, "policy": "rl"})
        # Random-action agent.
        # The random sampler depends on the number of moves, taken here from
        # the first member of the party.
        n_moves = len(party.pokes[0].moves)
        parties.append(party)
        action_samplers.append(generate_random_action_sampler(n_moves))
        metadata.append({"uuid": party_uuid, "policy": "random"})
    return parties, action_samplers, metadata
Example #3
0
def main():
    """Run ``hill_climbing_mp`` for every seed party in a process pool.

    Command line:
        dst_dir              output directory (created here; must not exist yet)
        seed_party           pickle file with the seed parties
        baseline_party_pool  pool of opponent parties used for rating
        baseline_party_rate  ratings of that opponent pool
        --rule, --neighbor, --iter, --match_count, -j  see the definitions below

    The collected results are written to ``<dst_dir>/parties.bin``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("dst_dir")
    cli.add_argument("seed_party")
    cli.add_argument("baseline_party_pool", help="レーティング測定相手パーティ群")
    cli.add_argument("baseline_party_rate", help="レーティング測定相手パーティ群のレーティング")
    # cli.add_argument("n_party", type=int)
    cli.add_argument("--rule", choices=[r.name for r in PartyRule], default=PartyRule.LV55_1.name)
    cli.add_argument("--neighbor", type=int, default=10, help="生成する近傍パーティ数")
    cli.add_argument("--iter", type=int, default=100, help="iteration数")
    cli.add_argument("--match_count", type=int, default=100, help="1パーティあたりの対戦回数")
    cli.add_argument("-j", type=int, help="並列処理数")
    opts = cli.parse_args()
    context.init()

    baseline_parties, baseline_rates = load_party_rate(opts.baseline_party_pool, opts.baseline_party_rate)
    generator = PartyGenerator(PartyRule[opts.rule])
    os.makedirs(opts.dst_dir)
    seed_parties = [entry["party"] for entry in load_pickle(opts.seed_party)["parties"]]

    # One argument tuple per seed party; the layout must match what hill_climbing_mp unpacks.
    job_args = [
        (generator, seed, baseline_parties, baseline_rates, opts.neighbor, opts.iter, opts.match_count, opts.dst_dir)
        for seed in seed_parties
    ]

    finished = []
    with Pool(processes=opts.j, initializer=process_init) as pool:
        for outcome in pool.imap_unordered(hill_climbing_mp, job_args):
            # Runs as each sample is produced (does not wait for all computation to finish).
            generated_party, rate, party_uuid, history_result = outcome
            finished.append({
                "party": generated_party,
                "uuid": party_uuid,
                "optimize_rate": rate,
                "history": history_result,
            })
            print(f"completed {len(finished)} / {len(seed_parties)}")

    save_pickle({"parties": finished}, os.path.join(opts.dst_dir, "parties.bin"))
Example #4
0
def main():
    """Generate ``n_party`` parties and pickle them to ``dst``.

    Command line:
        dst            path of the pickle file to write
        n_party        number of parties to generate
        --rule         name of a ``PartyRule`` member (default ``LV55_1``)
        --move_value   pickle file whose ``"avg"`` entry is passed to
                       ``PartyGeneratorWithMoveValue``
        --temperature  softmax temperature; required with ``--move_value``

    Each generated party is stored as ``{"party": ..., "uuid": ...}`` with a
    fresh ``uuid4`` string. Exits with a usage error when ``--move_value`` is
    given without ``--temperature``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dst")
    parser.add_argument("n_party", type=int)
    parser.add_argument("--rule",
                        choices=[r.name for r in PartyRule],
                        default=PartyRule.LV55_1.name)
    parser.add_argument("--move_value", help="estimate_move_value.pyで生成したスコア")
    parser.add_argument("--temperature",
                        type=float,
                        help="move_valueを使用する場合のsoftmax temperature")
    args = parser.parse_args()
    # Validate with parser.error rather than assert: asserts are stripped
    # under ``python -O`` and give the user a bare traceback.
    if args.move_value and args.temperature is None:
        parser.error("--temperature is required when --move_value is given")
    context.init()
    if args.move_value:
        partygen = PartyGeneratorWithMoveValue(
            load_pickle(args.move_value)["avg"], args.temperature,
            PartyRule[args.rule])
    else:
        partygen = PartyGenerator(PartyRule[args.rule])
    parties = [{
        "party": partygen.generate(),
        "uuid": str(uuid.uuid4())
    } for _ in range(args.n_party)]
    save_pickle({"parties": parties}, args.dst)
Example #5
0
def filter_party(out_prefix, parties_file, rates_file, count):
    """Keep the highest-rated parties and write them out with their rates.

    Parties are ranked by ``(rate, position in the input list)`` in descending
    order, so among equal rates the later party ranks first.

    :param out_prefix: output prefix; the parties go to ``<out_prefix>.bin``
        and the rates to ``<out_prefix>_rate.bin``
    :param parties_file: pickle file whose ``"parties"`` entry is a list of
        dicts carrying a ``"uuid"`` key
    :param rates_file: pickle file whose ``"rates"`` entry maps uuid to rate
    :param count: number of parties to keep; any negative value means half of
        the input (rounded down)
    """
    all_parties = load_pickle(parties_file)["parties"]
    rate_by_uuid = load_pickle(rates_file)["rates"]

    # Pair every rate with its list position so the party can be recovered
    # after sorting.
    ranked = sorted(
        ((rate_by_uuid[entry["uuid"]], pos)
         for pos, entry in enumerate(all_parties)),
        reverse=True)

    if count < 0:
        count = len(ranked) // 2

    kept_parties = [all_parties[pos] for _, pos in ranked[:count]]
    kept_rates = {
        entry["uuid"]: rate_by_uuid[entry["uuid"]]
        for entry in kept_parties
    }

    save_pickle({"parties": kept_parties}, out_prefix + ".bin")
    save_pickle({"rates": kept_rates}, out_prefix + "_rate.bin")
Example #6
0
def main():
    """Call ``train`` for a range of friend parties against an enemy pool.

    Command line:
        dst          output directory (created here; must not exist yet)
        friend_pool  pickle file with the parties to train for
        enemy_pool   pickle file with the opponent parties
        --count      number of parties to process; negative (the default)
                     means all parties from ``--skip`` onwards
        --skip       index of the first party to process (default 0)

    The party at index ``i`` gets the output directory ``<dst>/party_<i>``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("dst")
    cli.add_argument("friend_pool")
    cli.add_argument("enemy_pool")
    cli.add_argument("--count", type=int, default=-1,
                     help="いくつのパーティについて方策を学習するか")
    cli.add_argument("--skip", type=int, default=0)
    opts = cli.parse_args()
    context.init()

    friend_entries = load_pickle(opts.friend_pool)["parties"]
    friend_pool = [e["party"] for e in friend_entries]  # type: List[Party]
    enemy_entries = load_pickle(opts.enemy_pool)["parties"]
    enemy_pool = [e["party"] for e in enemy_entries]  # type: List[Party]

    os.makedirs(opts.dst)

    first = opts.skip
    n_parties = opts.count if opts.count >= 0 else len(friend_pool) - first
    for idx in range(first, first + n_parties):
        train(os.path.join(opts.dst, f"party_{idx}"), friend_pool[idx],
              enemy_pool)
Example #7
0
def main():
    """Rate the parties of the given files and pickle a uuid -> rate mapping.

    Command line:
        dst            path of the pickle file to write
        party_file     one or more party pickle files
        --match_count  forwarded to ``rating_battle`` (default 100)

    The output pickle has the single key ``"rates"``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("dst")
    cli.add_argument("party_file", nargs="+")
    cli.add_argument("--match_count", type=int, default=100,
                     help="1パーティあたりの対戦回数")
    opts = cli.parse_args()
    context.init()

    # Concatenate the entries of all files, keeping file and in-file order.
    entries = [
        entry
        for path in opts.party_file
        for entry in load_pickle(path)["parties"]
    ]
    rates = rating_battle([entry["party"] for entry in entries],
                          opts.match_count)

    # Rates are matched to parties by position.
    uuid_rates = {}
    for entry, rate in zip(entries, rates):
        uuid_rates[entry["uuid"]] = rate
    save_pickle({"rates": uuid_rates}, opts.dst)
Example #8
0
def main():
    """Evaluate stored agents against a rated enemy pool, one party at a time.

    Command line:
        agents_pool           directory containing ``party_<i>`` subdirectories
        friend_pool           pickle file with the parties the agents belong to
        test_enemy_pool       pool of opponent parties
        test_enemy_pool_rate  ratings of that opponent pool
        --count               number of parties to process; negative (the
                              default) means all parties from ``--skip`` onwards
        --skip                index of the first party to process (default 0)

    For each index ``i`` the first ``*_finish`` match inside
    ``<agents_pool>/party_<i>`` is evaluated, and the result is printed and
    saved to ``eval.yaml`` in that party directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("agents_pool")
    cli.add_argument("friend_pool")
    cli.add_argument("test_enemy_pool")
    cli.add_argument("test_enemy_pool_rate")
    cli.add_argument("--count", type=int, default=-1, help="いくつのパーティについて方策を学習するか")
    cli.add_argument("--skip", type=int, default=0)
    opts = cli.parse_args()
    context.init()

    friend_entries = load_pickle(opts.friend_pool)["parties"]
    friend_pool = [e["party"] for e in friend_entries]  # type: List[Party]
    test_enemy_pool, test_enemy_pool_rates = load_party_rate(opts.test_enemy_pool, opts.test_enemy_pool_rate)

    first = opts.skip
    n_parties = opts.count if opts.count >= 0 else len(friend_pool) - first
    for idx in range(first, first + n_parties):
        party_dir = os.path.join(opts.agents_pool, f"party_{idx}")
        finished_dirs = glob.glob(os.path.join(party_dir, "*_finish"))
        agent_dir = finished_dirs[0]
        friend_party = friend_pool[idx]
        rates = eval_agent(agent_dir, friend_party, test_enemy_pool, test_enemy_pool_rates)
        print(friend_party)
        print(rates)
        save_yaml({"rates": rates, "party_str": str(friend_party)},
                  os.path.join(party_dir, "eval.yaml"))