예제 #1
0
파일: run_atari.py 프로젝트: lhm3561/OSS
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
    """Train an ACER agent on an Atari environment.

    :param env_id: (str) environment ID passed to ``make_atari``
    :param num_timesteps: (int) requested number of timesteps; ``learn`` is
        given ``int(num_timesteps * 1.1)``
    :param seed: (int) base seed; worker ``rank`` is seeded with ``seed + rank``
    :param policy: (str) ``'cnn'`` or ``'lstm'``; any other value prints a
        message and returns without training
    :param lrschedule: (str) learning-rate schedule name forwarded to ``learn``
    :param num_cpu: (int) number of environment worker processes
    """
    # Resolve the policy before spawning any worker processes, so an invalid
    # choice does not leave a SubprocVecEnv running un-closed.
    if policy == 'cnn':
        policy_fn = AcerCnnPolicy
    elif policy == 'lstm':
        policy_fn = AcerLstmPolicy
    else:
        print("Policy {} not implemented".format(policy))
        return

    def make_env(rank):
        """Return a zero-argument factory that builds the env for ``rank``."""
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            # If logger.get_dir() is falsy, that falsy value is passed instead
            # of a path (presumably disabling the monitor log file).
            env = bench.Monitor(
                env,
                logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)

        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        # NOTE(review): the 1.1 factor is presumably headroom so training
        # actually reaches num_timesteps -- confirm intent.
        learn(policy_fn,
              env,
              seed,
              total_timesteps=int(num_timesteps * 1.1),
              lrschedule=lrschedule)
    finally:
        # Always shut down the worker processes, even if learn() raises.
        env.close()
예제 #2
0
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu, perform, use_expert, save_networks, learn_time, expert_buffer_size):
    """Train an ACER agent (with optional expert data) on an Atari environment.

    :param env_id: (str) environment ID passed to ``make_atari``; also used as
        the sub-directory name under ``./saved_networks``
    :param num_timesteps: (int) requested number of timesteps; ``learn`` is
        given ``int(num_timesteps * 1.1)``
    :param seed: (int) base seed; worker ``rank`` is seeded with ``seed + rank``
    :param policy: (str) ``'cnn'`` or ``'lstm'``; any other value prints a
        message and returns without training
    :param lrschedule: (str) learning-rate schedule name forwarded to ``learn``
    :param num_cpu: (int) number of environment worker processes
    :param perform: forwarded positionally to ``learn``
    :param use_expert: forwarded positionally to ``learn``
    :param save_networks: forwarded positionally to ``learn``
    :param learn_time: forwarded positionally to ``learn``
    :param expert_buffer_size: forwarded positionally to ``learn``
    """
    # Resolve the policy before spawning any worker processes, so an invalid
    # choice does not leave a SubprocVecEnv running un-closed.
    if policy == 'cnn':
        policy_fn = AcerCnnPolicy
    elif policy == 'lstm':
        policy_fn = AcerLstmPolicy
    else:
        print("Policy {} not implemented".format(policy))
        return

    # Trailing '/' kept: learn() presumably appends file names to this
    # string directly -- confirm before changing.
    network_saving_dir = os.path.join('./saved_networks', env_id) + '/'
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(network_saving_dir, exist_ok=True)

    def make_env(rank):
        """Return a zero-argument factory that builds the env for ``rank``."""
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            # If logger.get_dir() is falsy, that falsy value is passed instead
            # of a path (presumably disabling the monitor log file).
            env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)
        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        learn(policy_fn, env, seed, env_id, learn_time, expert_buffer_size, perform, use_expert, save_networks, network_saving_dir, int(num_timesteps * 1.1), lrschedule=lrschedule)
    finally:
        # Always shut down the worker processes, even if learn() raises.
        env.close()
예제 #3
0
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
    """Train an ACER agent on ``MarioEnv`` worker processes.

    :param env_id: (str) accepted for signature compatibility; not used here
    :param num_timesteps: (int) requested number of timesteps; ``learn`` is
        given ``int(num_timesteps * 1.1)``
    :param seed: (int) base seed; worker ``rank`` is seeded with ``seed + rank``
    :param policy: (str) ``'cnn'`` or ``'lstm'``; any other value prints a
        message and returns without training
    :param lrschedule: (str) learning-rate schedule name forwarded to ``learn``
    :param num_cpu: (int) number of environment worker processes; with 1
        worker ``MarioEnv(num_steering_dir=0)`` is built, otherwise
        ``MarioEnv(num_steering_dir=11, num_env=rank)``
    """
    # Resolve the policy before spawning any worker processes, so an invalid
    # choice does not leave a SubprocVecEnv running un-closed.
    if policy == 'cnn':
        policy_fn = AcerCnnPolicy
    elif policy == 'lstm':
        policy_fn = AcerLstmPolicy
    else:
        print("Policy {} not implemented".format(policy))
        return

    def make_env(rank):
        """Return a zero-argument factory that builds the env for ``rank``."""
        def env_fn():
            # NOTE(review): looks like a leftover debug print of the worker rank.
            print(rank)
            if num_cpu == 1:
                env = MarioEnv(num_steering_dir=0)
            else:
                env = MarioEnv(num_steering_dir=11, num_env=rank)
            env.seed(seed + rank)
            # If logger.get_dir() is falsy, that falsy value is passed instead
            # of a path (presumably disabling the monitor log file).
            env = bench.Monitor(
                env,
                logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            gym.logger.setLevel(logging.WARN)
            return env

        return env_fn

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        learn(policy_fn,
              env,
              seed,
              nsteps=50,
              total_timesteps=int(num_timesteps * 1.1),
              lrschedule=lrschedule,
              buffer_size=15000,
              gamma=0.95)
    finally:
        # Always shut down the worker processes, even if learn() raises.
        env.close()
예제 #4
0
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu):
    """
    train an ACER model on atari

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The requested number of samples; ``learn`` is
        given ``int(num_timesteps * 1.1)``
    :param seed: (int) The initial seed for training
    :param policy: (str) Name of the policy to use: 'cnn' (AcerCnnPolicy) or
        'lstm' (AcerLstmPolicy); any other value prints a message and returns
        without training
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
                                 'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_cpu: (int) The number of cpu to train on
    """
    # Resolve the policy before building the environment, so an invalid
    # choice does not leave an env running un-closed.
    if policy == 'cnn':
        policy_fn = AcerCnnPolicy
    elif policy == 'lstm':
        policy_fn = AcerLstmPolicy
    else:
        print("Policy {} not implemented".format(policy))
        return

    env = make_atari_env(env_id, num_cpu, seed)
    try:
        learn(policy_fn,
              env,
              seed,
              total_timesteps=int(num_timesteps * 1.1),
              lr_schedule=lr_schedule,
              buffer_size=5000)
    finally:
        # Always release the environment, even if learn() raises.
        env.close()
예제 #5
0
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
    """Train an ACER agent on an Atari environment built by ``make_atari_env``.

    :param env_id: (str) environment ID
    :param num_timesteps: (int) requested number of timesteps; ``learn`` is
        given ``int(num_timesteps * 1.1)``
    :param seed: (int) seed passed to both ``make_atari_env`` and ``learn``
    :param policy: (str) ``'cnn'`` or ``'lstm'``; any other value prints a
        message and returns without training
    :param lrschedule: (str) learning-rate schedule name forwarded to ``learn``
    :param num_cpu: (int) number of environments passed to ``make_atari_env``
    """
    # Resolve the policy before building the environment, so an invalid
    # choice does not leave an env running un-closed.
    if policy == 'cnn':
        policy_fn = AcerCnnPolicy
    elif policy == 'lstm':
        policy_fn = AcerLstmPolicy
    else:
        print("Policy {} not implemented".format(policy))
        return

    env = make_atari_env(env_id, num_cpu, seed)
    try:
        learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    finally:
        # Always release the environment, even if learn() raises.
        env.close()
예제 #6
0
def main():
    """Command-line entry point: parse args, build the Atari env, run ACER.

    Reads an optional ``--flags``/``-f`` cfg file into ``AcerFlags``; when it
    is absent, default ``AcerFlags()`` are used.
    """
    parser = atari_arg_parser()
    parser.add_argument('--flags', '-f', help="flags cfg file", default=None)
    args = parser.parse_args()

    flags = AcerFlags.from_cfg(args.flags) if args.flags else AcerFlags()
    logger.configure(flags.log_dir)

    env = make_atari_env(args.env, num_env=flags.num_env, seed=flags.seed)
    try:
        policy_fn = models.get(args.policy)
        learn(policy_fn, env, flags)
    finally:
        # Always release the environment, even if policy lookup or
        # training raises.
        env.close()
예제 #7
0
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
    """Train an ACER agent on an Atari environment built by ``make_atari_env``.

    :param env_id: (str) environment ID
    :param num_timesteps: (int) requested number of timesteps; ``learn`` is
        given ``int(num_timesteps * 1.1)``
    :param seed: (int) seed passed to both ``make_atari_env`` and ``learn``
    :param policy: (str) ``'cnn'`` or ``'lstm'``; any other value prints a
        message and returns without training
    :param lrschedule: (str) learning-rate schedule name forwarded to ``learn``
    :param num_cpu: (int) number of environments passed to ``make_atari_env``
    """
    # Resolve the policy before building the environment, so an invalid
    # choice does not leave an env running un-closed.
    if policy == 'cnn':
        policy_fn = AcerCnnPolicy
    elif policy == 'lstm':
        policy_fn = AcerLstmPolicy
    else:
        print("Policy {} not implemented".format(policy))
        return

    env = make_atari_env(env_id, num_cpu, seed)
    try:
        learn(policy_fn,
              env,
              seed,
              total_timesteps=int(num_timesteps * 1.1),
              lrschedule=lrschedule)
    finally:
        # Always release the environment, even if learn() raises.
        env.close()
예제 #8
0
def train(num_timesteps, seed, policy, lrschedule):
    """Train an ACER agent on the ``GazeboSmartBotPincherKinect-v0`` gym env.

    :param num_timesteps: (int) requested number of timesteps; ``learn`` is
        given ``int(num_timesteps * 1.1)``
    :param seed: (int) seed forwarded to ``learn``
    :param policy: (str) ``'cnn'`` or ``'lstm'``; any other value prints a
        message and returns without training
    :param lrschedule: (str) learning-rate schedule name forwarded to ``learn``
    """
    # Resolve the policy before creating the environment, so an invalid
    # choice does not leave an env open.
    if policy == 'cnn':
        policy_fn = AcerCnnPolicy
    elif policy == 'lstm':
        policy_fn = AcerLstmPolicy
    else:
        print("Policy {} not implemented".format(policy))
        return

    # NOTE(review): gym.make returns a single (non-vectorized) env; confirm
    # that this learn() accepts it rather than requiring a VecEnv.
    env = gym.make('GazeboSmartBotPincherKinect-v0')
    try:
        learn(policy_fn,
              env,
              seed,
              total_timesteps=int(num_timesteps * 1.1),
              lrschedule=lrschedule)
    finally:
        # Always release the environment, even if learn() raises.
        env.close()
예제 #9
0
def main():
    """Command-line entry point: load flags, build the Rogue env, run ACER.

    Reads an optional ``--flags``/``-f`` cfg file into ``RogueAcerFlags``;
    when it is absent, default ``RogueAcerFlags()`` are used. The flags are
    registered with ``RogueEnv`` before the logger and env are set up.
    """
    parser = arg_parser()
    parser.add_argument(
        '--flags',
        '-f',
        help="flags cfg file (will load checkpoint in save dir if found)",
        default=None)
    args = parser.parse_args()

    flags = RogueAcerFlags.from_cfg(
        args.flags) if args.flags else RogueAcerFlags()
    RogueEnv.register(flags)
    logger.configure(flags.log_dir)

    env = make_rogue_env(num_env=flags.num_env, seed=flags.seed)
    try:
        set_global_seeds(flags.seed)
        policy_fn = models.get(flags.policy)
        learn(policy_fn, env, flags)
    finally:
        # Always release the environment, even if seeding, policy lookup or
        # training raises.
        env.close()