Code example #1
0
    def play_wrapper(model_names, n_rounds):
        """Spawn model worker processes, play *n_rounds* of head-to-head
        games, and report the win rate of the first model.

        Parameters
        ----------
        model_names : sequence of tuples
            One entry per side.  item[1] is the model's name, item[-1] the
            model class, and (item[0], item[2]) the checkpoint directory and
            epoch to restore.  NOTE(review): tuple layout inferred from the
            indexing below — confirm against the caller.
        n_rounds : int
            Number of rounds to play; sides swap every round.

        Returns
        -------
        (win_rate, avg_num, elapsed)
            Fraction of rounds won by model 0, mean per-side round counts
            (np.ndarray of length 2), and wall-clock seconds taken.
        """
        time_stamp = time.time()

        # One worker process per side.
        models = [
            magent.ProcessingModel(env, handles[i], item[1], 0, item[-1])
            for i, item in enumerate(model_names)
        ]

        result = 0
        total_num = np.zeros(2)
        try:
            # Restore each model from its checkpoint.
            for i, item in enumerate(model_names):
                models[i].load(item[0], item[2])

            leftID, rightID = 0, 1
            for _ in range(n_rounds):
                round_num = play(env, handles, models, map_size, leftID, rightID)
                total_num += round_num
                # Swap sides so neither model benefits from a map-side bias.
                leftID, rightID = rightID, leftID
                result += 1 if round_num[0] > round_num[1] else 0
        finally:
            # Always tear down the worker processes, even if load()/play()
            # raises — the original leaked them on any exception.
            for model in models:
                model.quit()

        # 1.0 * keeps the division a float division under Python 2 as well.
        return 1.0 * result / n_rounds, total_num / n_rounds, time.time(
        ) - time_stamp
Code example #2
0
    # Render frames are written under build/render for later replay.
    env.set_render_dir("build/render")

    # two groups of agents (one handle per group)
    handles = env.get_handles()

    # load models
    names = ["predator", "prey"]
    models = []

    # One worker process per group.  NOTE(review): the positional args after
    # the name look like (port=20000+i, buffer_capacity=4000) — confirm
    # against magent.ProcessingModel's signature.
    for i in range(len(names)):
        models.append(
            magent.ProcessingModel(env,
                                   handles[i],
                                   names[i],
                                   20000 + i,
                                   4000,
                                   DeepQNetwork,
                                   batch_size=512,
                                   memory_size=2**22,
                                   target_update=1000,
                                   train_freq=4))

    # load if a checkpoint epoch was requested; otherwise train from scratch
    savedir = 'save_model'
    if args.load_from is not None:
        # Resume: restore every model from the requested epoch.
        start_from = args.load_from
        print("load ... %d" % start_from)
        for model in models:
            model.load(savedir, start_from)
    else:
        start_from = 0
Code example #3
0
File: train_multi.py  Project: goldenair/multii-agent
        # Unsupported algorithm option from the selection chain above.
        raise NotImplementedError
    else:
        raise NotImplementedError

    # load models: four groups, two per side (left l0/l1, right r0/r1)
    names = [
        args.name + "-l0", args.name + "-l1", args.name + "-r0",
        args.name + "-r1"
    ]
    models = []

    # One worker process per group; each gets its own eval-observation set
    # layered over the shared base arguments.  NOTE(review): the positional
    # args look like (port=20000+i, buffer_capacity=1000) — confirm against
    # magent.ProcessingModel's signature.
    for i in range(len(names)):
        model_args = {'eval_obs': eval_obs[i]}
        model_args.update(base_args)
        models.append(
            magent.ProcessingModel(env, handles[i], names[i], 20000 + i, 1000,
                                   RLModel, **model_args))

    # load if a checkpoint epoch was requested; otherwise train from scratch
    savedir = 'save_model'
    if args.load_from is not None:
        # Resume: restore every model from the requested epoch.
        start_from = args.load_from
        print("load ... %d" % start_from)
        for model in models:
            model.load(savedir, start_from)
    else:
        start_from = 0

    # print state info
    print(args)
    print("view_size", env.get_view_space(handles[0]))
    print("feature_size", env.get_feature_space(handles[0]))
Code example #4
0
    else:
        # No held-out observations for evaluation.
        eval_obs = [None, None]

    # init models
    names = [args.name + "-a", "battle"]
    batch_size = 512
    unroll_step = 16
    train_freq = 5

    models = []

    # load opponent (slot 0 of `models`, driving group handles[1])
    if args.opponent >= 0:
        # Non-negative index: restore a pretrained DQN opponent from that
        # checkpoint epoch.
        from models.tf_model import DeepQNetwork
        models.append(
            magent.ProcessingModel(env, handles[1], names[1], 20000, 0,
                                   DeepQNetwork))
        models[0].load("data/battle_model", args.opponent)
    else:
        # Negative index: fight a uniformly random opponent instead.
        models.append(
            magent.ProcessingModel(env, handles[1], names[1], 20000, 0,
                                   RandomActor))
    # load our model
    if args.alg == 'dqn':
        from models.tf_model import DeepQNetwork
        models.append(
            magent.ProcessingModel(env,
                                   handles[0],
                                   names[0],
                                   20001,
                                   1000,