コード例 #1
0
ファイル: train_battle.py プロジェクト: dkkim93/mfrl
def main(args):
    """Train two same-algorithm agents against each other on the 'battle' map.

    Builds a MAgent ``battle`` GridWorld, spawns one model per team handle
    (both of type ``args.algo``; named ``<algo>-me`` and ``<algo>-opponent``),
    initializes all TF variables, and drives ``args.n_round`` training rounds
    through ``tools.Runner`` with ``play`` as the per-round routine.

    The exploration rate for round ``k`` comes from ``linear_decay`` with
    milestones ``[0, int(0.8 * n_round), n_round]`` and values
    ``[1, 0.2, 0.1]`` (presumably piecewise-linear interpolation between
    them -- confirm against ``linear_decay``).

    Args:
        args: parsed command-line namespace. This function reads
            ``map_size``, ``algo``, ``max_steps``, ``save_every``,
            ``render`` and ``n_round``.
    """
    # Initialize the environment; rendered episodes are written under BASE_DIR.
    env = magent.GridWorld('battle', map_size=args.map_size)
    env.set_render_dir(
        os.path.join(BASE_DIR, 'examples/battle_model', 'build/render'))
    handles = env.get_handles()

    tf_config = tf.ConfigProto(allow_soft_placement=True,
                               log_device_placement=False)
    # Allocate GPU memory on demand rather than reserving it all up front.
    tf_config.gpu_options.allow_growth = True

    # Shared log directory. There is no per-algorithm subdirectory here;
    # Runner receives log_name=args.algo below (presumably it uses that to
    # separate runs -- confirm in tools.Runner).
    log_dir = os.path.join(BASE_DIR, 'data/tmp')
    model_dir = os.path.join(BASE_DIR, 'data/models/{}'.format(args.algo))

    start_from = 0

    sess = tf.Session(config=tf_config)
    # Close the session even if model construction or training raises, so
    # its resources are released. An explicit try/finally is used instead of
    # `with`, because `with` would also install it as the default session.
    try:
        models = [
            spawn_ai(args.algo, sess, env, handles[0], args.algo + '-me',
                     args.max_steps),
            spawn_ai(args.algo, sess, env, handles[1], args.algo + '-opponent',
                     args.max_steps)
        ]
        sess.run(tf.global_variables_initializer())
        runner = tools.Runner(sess,
                              env,
                              handles,
                              args.map_size,
                              args.max_steps,
                              models,
                              play,
                              # Render only when requested, at the save cadence.
                              render_every=args.save_every if args.render else 0,
                              save_every=args.save_every,
                              tau=0.01,
                              log_name=args.algo,
                              log_dir=log_dir,
                              model_dir=model_dir,
                              train=True)

        for k in range(start_from, start_from + args.n_round):
            eps = linear_decay(k, [0, int(args.n_round * 0.8), args.n_round],
                               [1, 0.2, 0.1])
            runner.run(eps, k)
    finally:
        sess.close()
コード例 #2
0
ファイル: battle.py プロジェクト: RicMat/mfrl
                                  'data/models/{}-1'.format(args.oppo))

    sess = tf.Session(config=tf_config)
    models = [
        spawn_ai(args.algo, sess, env, handles[0], args.algo + '-me',
                 args.max_steps),
        spawn_ai(args.oppo, sess, env, handles[1], args.oppo + '-opponent',
                 args.max_steps)
    ]
    sess.run(tf.global_variables_initializer())

    models[0].load(main_model_dir, step=args.idx[0])
    models[1].load(oppo_model_dir, step=args.idx[1])

    runner = tools.Runner(sess,
                          env,
                          handles,
                          args.map_size,
                          args.max_steps,
                          models,
                          battle,
                          render_every=0)
    win_cnt = {'main': 0, 'opponent': 0}

    for k in range(0, args.n_round):
        runner.run(0.0, k, win_cnt=win_cnt)

    print('\n[*] >>> WIN_RATE: [{0}] {1} / [{2}] {3}'.format(
        args.algo, win_cnt['main'] / args.n_round, args.oppo,
        win_cnt['opponent'] / args.n_round))
コード例 #3
0
                 args.max_steps),
        spawn_ai(args.algo, sess, env, handles[2], args.algo + '-opponent2',
                 args.max_steps),
        spawn_ai(args.algo, sess, env, handles[3], args.algo + '-opponent3',
                 args.max_steps)
    ]
    sess.run(tf.global_variables_initializer())
    if args.algo == 'mtmfq':
        runner = tools.Runner(
            sess,
            env,
            handles,
            args.map_size,
            args.max_steps,
            models,
            play2,
            render_every=args.save_every if args.render else 0,
            save_every=args.save_every,
            tau=0.01,
            log_name=args.algo,
            log_dir=log_dir,
            model_dir=model_dir,
            train=True)

    else:
        runner = tools.Runner(
            sess,
            env,
            handles,
            args.map_size,
            args.max_steps,
コード例 #4
0
ファイル: train_battle.py プロジェクト: bellmanequation/LSC
            MsgModels[1].load(oppo_msg_dir, step=args.idx)
    if args.crp != 'None':
        crp = True
    else:
        crp = False

    runner = tools.Runner(sess,
                          env,
                          handles,
                          args.map_size,
                          args.max_steps,
                          models,
                          MsgModels,
                          play,
                          render_every=args.save_every if args.render else 0,
                          save_every=args.save_every,
                          tau=0.01,
                          log_name=args.algo,
                          log_dir=log_dir,
                          model_dir=model_dir,
                          train=True,
                          len_nei=args.len_nei,
                          rewardtype=args.rewardtype,
                          crp=crp,
                          is_selfplay=True,
                          is_fix=False)

    for k in range(start_from, start_from + args.n_round):
        eps = magent.utility.piecewise_decay(k, [0, 700, 1400, 8000],
                                             [1, 0.3, 0.05, 0.01])
        runner.run(eps, k)