Пример #1
0
def main(game_name, lr, num_agents, update_target_every, model_name, tau):
    """Train an agent with ``dqn`` on a ``*NoFrameskip-v4`` environment.

    Args:
        game_name: Environment id; must end with ``'NoFrameskip-v4'``. The
            part before that suffix becomes the first field of the run name.
        lr: Learning rate forwarded to ``get_estimator``.
        num_agents: Forwarded to ``Agent`` and recorded in the run name.
        update_target_every: Target-update interval forwarded to ``dqn``.
            Forced to 1 when ``model_name`` contains ``'soft'``.
        model_name: Estimator name forwarded to ``get_estimator`` and
            recorded in the run name.
        tau: Forwarded to ``get_estimator``; appended to the run name only
            for ``'soft'`` models.

    Outputs go under ``os.path.join(train_path, <run name>)``. A
    ``KeyboardInterrupt`` or any other ``Exception`` raised while training is
    reported (not re-raised), and the environment is always closed.
    """
    suffix = 'NoFrameskip-v4'
    # The run name is built by stripping this suffix, so it must really be a
    # suffix, not merely a substring somewhere inside ``game_name``.
    assert game_name.endswith(suffix), game_name

    is_soft = 'soft' in model_name
    if is_soft:
        # NOTE(review): presumably soft target updates are applied every
        # step; confirm against dqn().
        update_target_every = 1

    basename = '{}:lr={}:na={}:ute={}:{}'.format(
        game_name[:-len(suffix)], lr, num_agents, update_target_every,
        model_name)
    if is_soft:
        basename += ':tau={}'.format(tau)

    env = Agent(num_agents, game_name, basename)
    try:
        estimator = get_estimator(model_name, env.action_n, lr, 0.99, tau=tau)
        base_path = os.path.join(train_path, basename)
        print("start training!!")
        dqn(env,
            estimator,
            base_path,
            batch_size=32,
            epsilon=0.01,
            save_model_every=1000,
            update_target_every=update_target_every,
            learning_starts=200,
            memory_size=100000,
            num_iterations=40000000)
    except KeyboardInterrupt:
        print("\nKeyboard interrupt!!")
    except Exception:
        # Top-level boundary: report the failure but still close the env.
        traceback.print_exc()
    finally:
        env.close()
Пример #2
0
def main(game_name, lr, num_agents, update_target_every, model_name):
    """Train an agent with ``dqn`` on a ``*NoFrameskip-v4`` environment.

    Args:
        game_name: Environment id; must end with ``'NoFrameskip-v4'``. The
            part before that suffix becomes the first field of the run name.
        lr: Learning rate forwarded to ``get_estimator``.
        num_agents: Forwarded to ``Agent`` and recorded in the run name.
        update_target_every: Target-update interval forwarded to ``dqn``.
        model_name: Estimator name forwarded to ``get_estimator`` and
            recorded in the run name.

    Outputs go under ``os.path.join(train_path, <run name>)``. A
    ``KeyboardInterrupt`` or any other ``Exception`` raised while training is
    reported (not re-raised), and the environment is always closed.
    """
    suffix = 'NoFrameskip-v4'
    # The run name is built by stripping this suffix, so it must really be a
    # suffix, not merely a substring somewhere inside ``game_name``.
    assert game_name.endswith(suffix), game_name

    basename = '{}:lr={}:na={}:ute={}:{}'.format(
        game_name[:-len(suffix)], lr, num_agents, update_target_every,
        model_name)

    env = Agent(num_agents, game_name, basename)
    try:
        estimator = get_estimator(model_name, env.action_n, lr, 0.99)
        base_path = os.path.join(train_path, basename)
        print("start training!!")
        dqn(env,
            estimator,
            base_path,
            batch_size=32,
            epsilon=0.01,
            save_model_every=1000,
            update_target_every=update_target_every,
            learning_starts=200,
            memory_size=100000,
            num_iterations=40000000)
    except KeyboardInterrupt:
        print("\nKeyboard interrupt!!")
    except Exception:
        # Top-level boundary: report the failure but still close the env.
        traceback.print_exc()
    finally:
        env.close()
Пример #3
0
from estimator import get_estimator
from data import get_input_fn

if __name__ == '__main__':
    # Command-line entry point: build the named estimator and train it on a
    # shuffled input pipeline.
    import argparse
    parser = argparse.ArgumentParser(description='training script')
    parser.add_argument('model_name', nargs='?', type=str, default='base')
    # No nargs='?' on the options: with it, a bare ``-b``/``-s`` silently
    # yields None instead of an argparse error.
    parser.add_argument('--batch_size', '-b', type=int, default=64)
    # argparse applies ``type`` only to string defaults, so the default must
    # already be an int (``1e6`` would stay a float).
    parser.add_argument('--max_steps', '-s', type=int, default=1000000)
    args = parser.parse_args()

    estimator = get_estimator(args.model_name)
    input_fn = get_input_fn(args.batch_size, shuffle=True)
    estimator.train(input_fn, max_steps=args.max_steps)
Пример #4
0
def _load_checkpoint(estimator, basename):
    """Load the model saved under ``train_path/<basename>/models``."""
    estimator.load_model(os.path.join(train_path, basename, 'models'))


def _visualize(env, estimator, basename):
    """Play one full game with run ``basename``'s model while recording.

    The writer targets ``./videos/<basename>-<global_step>.mp4`` and is handed
    to ``env.reset(videowriter=...)``. Presumably the env wrapper appends
    frames to it (TODO confirm). The writer is always closed, even if a step
    raises.
    """
    _load_checkpoint(estimator, basename)
    total_t = estimator.get_global_step()

    os.makedirs('./videos', exist_ok=True)
    video_path = './videos/{}-{}.mp4'.format(basename, total_t)
    video_writer = imageio.get_writer(video_path, fps=30)
    try:
        state = env.reset(videowriter=video_writer)
        print(env.unwrapped.ale.lives())
        episode_return = 0
        num_steps = 0
        while True:
            # Second argument presumably epsilon; 0.0 -> greedy actions.
            action = estimator.get_action(np.array([state]), 0.0)
            state, reward, done, info = env.step(action)
            episode_return += reward
            num_steps += 1
            if not done:
                continue
            print(env.unwrapped.ale.lives())
            if info['was_real_done']:
                print(num_steps, episode_return)
                break
            # ``done`` without ``was_real_done`` (presumably a lost life):
            # reset and keep playing the same game.
            state = env.reset()
    finally:
        video_writer.close()


def _evaluate(env, estimator, basename, num_eval=10):
    """Play ``num_eval`` full games with run ``basename``'s model.

    Prints the mean and max total reward per game. Prints a notice instead
    if ``num_eval`` is 0.
    """
    _load_checkpoint(estimator, basename)

    returns = []
    for _ in tqdm(range(num_eval)):
        # Reduce the millisecond clock *modulo* 2**31 - 1 so each game gets a
        # different seed. Floor division would produce a value that changes
        # only every ~25 days, so every game would reuse one seed.
        env.seed(int(time.time() * 1000) % 2147483647)
        state = env.reset()
        episode_return = 0
        while True:
            action = estimator.get_action(np.array([state]), 0.0)
            state, reward, done, info = env.step(action)
            episode_return += reward
            if not done:
                continue
            if info['was_real_done']:
                returns.append(episode_return)
                break
            state = env.reset()

    if not returns:
        print('no episodes evaluated')
        return
    print('mean: {}, max: {}'.format(sum(returns) / len(returns),
                                     max(returns)))


def main(game_name, model_name, write_video):
    """Record or evaluate every saved run matching a game and model name.

    Args:
        game_name: Environment id; must end with ``'NoFrameskip-v4'``.
        model_name: Estimator name passed to ``get_estimator``. Run
            directories under ``train_path`` whose names contain both the
            game prefix and this string are processed, in sorted order.
        write_video: If truthy, record one game per run to ``./videos``.
            Otherwise evaluate each run and print mean/max reward.
    """
    suffix = 'NoFrameskip-v4'
    # The game prefix is obtained by stripping this suffix, so it must really
    # be a suffix of ``game_name``.
    assert game_name.endswith(suffix), game_name
    game_prefix = game_name[:-len(suffix)]

    env = atari_env(game_name)
    try:
        estimator = get_estimator(model_name, env.action_space.n, 0.001, 0.99)

        # os.listdir order is arbitrary; sort for reproducible output.
        basename_list = sorted(
            name for name in os.listdir(train_path)
            if game_prefix in name and model_name in name)
        print(basename_list)

        for basename in basename_list:
            if write_video:
                print("Writing {}'s video ...".format(basename))
                _visualize(env, estimator, basename)
            else:
                print("Evaluating {} ...".format(basename))
                _evaluate(env, estimator, basename)
    finally:
        env.close()