예제 #1
0
     from magent.builtin.tf_model import DeepQNetwork
     models.append(
         DeepQNetwork(env,
                      tiger_handle,
                      args.name,
                      batch_size=batch_size,
                      memory_size=2**20,
                      learning_rate=4e-4))
     step_batch_size = None
 elif args.alg == 'drqn':
     from magent.builtin.tf_model import DeepRecurrentQNetwork
     models.append(
         DeepRecurrentQNetwork(env,
                               tiger_handle,
                               "tiger",
                               batch_size=batch_size / unroll,
                               unroll_step=unroll,
                               memory_size=20000,
                               learning_rate=4e-4))
     step_batch_size = None
 elif args.alg == 'a2c':
     from magent.builtin.mx_model import AdvantageActorCritic
     step_batch_size = int(10 * args.map_size * args.map_size * 0.01)
     models.append(
         AdvantageActorCritic(env,
                              tiger_handle,
                              "tiger",
                              batch_size=step_batch_size,
                              learning_rate=1e-2))
 else:
     raise NotImplementedError
예제 #2
0
                         handles[0],
                         "battle",
                         batch_size=batch_size,
                         learning_rate=3e-4,
                         memory_size=2**21,
                         target_update=target_update,
                         train_freq=train_freq,
                         eval_obs=eval_obs))
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        models.append(
            DeepRecurrentQNetwork(env,
                                  handles[0],
                                  "battle",
                                  learning_rate=3e-4,
                                  batch_size=batch_size / unroll_step,
                                  unroll_step=unroll_step,
                                  memory_size=2 * 8 * 625,
                                  target_update=target_update,
                                  train_freq=train_freq,
                                  eval_obs=eval_obs))
    else:
        # see train_against.py to know how to use a2c
        raise NotImplementedError

    # Append a second reference to the first model: the new last entry of
    # `models` is the very same object as models[0], not a copy.
    models.append(models[0])

    # If args.load_from is given, start from that value (it is printed with
    # %d, so it is expected to be an integer).
    # NOTE(review): savedir is presumably the directory the models are
    # loaded from; the loading code is not visible here -- confirm.
    savedir = 'save_model'
    if args.load_from is not None:
        start_from = args.load_from
        print("load ... %d" % start_from)
예제 #3
0
        models.append(
            DeepQNetwork(env,
                         handles[0],
                         "selfplay",
                         batch_size=batch_size,
                         memory_size=2**20,
                         target_update=target_update,
                         train_freq=train_freq,
                         eval_obs=eval_obs))
    elif args.alg == 'drqn':
        models.append(
            DeepRecurrentQNetwork(env,
                                  handles[0],
                                  "selfplay",
                                  batch_size=batch_size / unroll_step,
                                  unroll_step=unroll_step,
                                  memory_size=2 * 8 * 625,
                                  target_update=target_update,
                                  train_freq=train_freq,
                                  eval_obs=eval_obs))
    else:
        raise NotImplementedError

    # Append a second reference to the first model: the new last entry of
    # `models` is the very same object as models[0], not a copy.
    models.append(models[0])

    # If args.load_from is given, start from that value (it is printed with
    # %d, so it is expected to be an integer).
    # NOTE(review): savedir is presumably the directory the models are
    # loaded from; the loading code is not visible here -- confirm.
    savedir = 'save_model'
    if args.load_from is not None:
        start_from = args.load_from
        print("load ... %d" % start_from)
        for model in models: