def main(game_name, lr, num_agents, update_target_every, model_name, tau):
    """Train a DQN-family agent on an Atari NoFrameskip environment.

    Args:
        game_name: Gym environment id; must contain 'NoFrameskip-v4'.
        lr: learning rate passed to the estimator.
        num_agents: number of parallel environment workers.
        update_target_every: target-network sync interval (forced to 1 for
            soft-update models, which blend the target every step).
        model_name: estimator variant key for ``get_estimator``.
        tau: soft-update mixing coefficient (only used by 'soft' models).
    """
    # Raise explicitly instead of `assert`: asserts are stripped under -O,
    # which would silently disable this validation.
    if 'NoFrameskip-v4' not in game_name:
        raise ValueError(
            "game_name must contain 'NoFrameskip-v4', got {!r}".format(game_name))

    # Single check for soft-update models (was duplicated in the original).
    is_soft = 'soft' in model_name
    if is_soft:
        update_target_every = 1

    # Strip the 14-char 'NoFrameskip-v4' suffix to get the bare game name.
    basename = '{}:lr={}:na={}:ute={}:{}'.format(
        game_name[:-14], lr, num_agents, update_target_every, model_name)
    if is_soft:
        basename += ':tau={}'.format(tau)

    env = Agent(num_agents, game_name, basename)
    try:
        estimator = get_estimator(model_name, env.action_n, lr, 0.99, tau=tau)
        base_path = os.path.join(train_path, basename)
        print("start training!!")
        dqn(env,
            estimator,
            base_path,
            batch_size=32,
            epsilon=0.01,
            save_model_every=1000,
            update_target_every=update_target_every,
            learning_starts=200,
            memory_size=100000,
            num_iterations=40000000)
    except KeyboardInterrupt:
        print("\nKeyboard interrupt!!")
    except Exception:
        # Top-level boundary: log the traceback rather than crash silently.
        traceback.print_exc()
    finally:
        # Always release environment resources, even on error/interrupt.
        env.close()
def main(game_name, lr, num_agents, update_target_every, model_name):
    """Train a DQN agent on an Atari NoFrameskip environment.

    Builds a run name from the hyper-parameters, constructs the parallel
    environment wrapper and the estimator, then hands off to ``dqn``.
    Cleans up the environment on exit, interrupt, or error.
    """
    assert 'NoFrameskip-v4' in game_name

    # Run identifier: bare game name (suffix stripped) plus hyper-parameters.
    run_name = f'{game_name[:-14]}:lr={lr}:na={num_agents}' \
               f':ute={update_target_every}:{model_name}'

    env = Agent(num_agents, game_name, run_name)
    try:
        model = get_estimator(model_name, env.action_n, lr, 0.99)
        out_dir = os.path.join(train_path, run_name)
        print("start training!!")
        train_kwargs = dict(
            batch_size=32,
            epsilon=0.01,
            save_model_every=1000,
            update_target_every=update_target_every,
            learning_starts=200,
            memory_size=100000,
            num_iterations=40000000,
        )
        dqn(env, model, out_dir, **train_kwargs)
    except KeyboardInterrupt:
        print("\nKeyboard interrupt!!")
    except Exception:
        # Top-level boundary: report the failure without re-raising.
        traceback.print_exc()
    finally:
        env.close()
from estimator import get_estimator
from data import get_input_fn

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='training script')
    parser.add_argument('model_name', nargs='?', type=str, default='base')
    parser.add_argument('--batch_size', '-b', nargs='?', type=int, default=64)
    # BUG FIX: argparse does NOT run defaults through `type`, so the old
    # `default=1e6` left max_steps as the float 1000000.0 whenever the flag
    # was omitted. Use an int literal so max_steps is always an int.
    parser.add_argument('--max_steps', '-s', nargs='?', type=int,
                        default=1000000)
    args = parser.parse_args()

    # Build the model and input pipeline, then train to the step budget.
    estimator = get_estimator(args.model_name)
    input_fn = get_input_fn(args.batch_size, shuffle=True)
    estimator.train(input_fn, max_steps=args.max_steps)
def main(game_name, model_name, write_video):
    """Evaluate (or record videos of) every trained run matching a game/model.

    Scans ``train_path`` for run directories whose names contain both the bare
    game name and ``model_name``, then either writes an .mp4 per run or plays
    ``num_eval`` episodes and reports mean/max score.

    Args:
        game_name: Gym environment id; must contain 'NoFrameskip-v4'.
        model_name: estimator variant key; also used to filter run dirs.
        write_video: if truthy, record videos; otherwise run evaluation.
    """
    assert 'NoFrameskip-v4' in game_name
    env = atari_env(game_name)
    # lr/discount are placeholders here; only inference weights are loaded.
    estimator = get_estimator(model_name, env.action_space.n, 0.001, 0.99)
    # Run dirs matching both the bare game name (14-char suffix stripped)
    # and the model name.
    basename_list = [
        name for name in os.listdir(train_path)
        if (game_name[:-14] in name) and (model_name in name)
    ]
    print(basename_list)

    def visualize(basename):
        """Play one full game greedily and record it to ./videos/."""
        checkpoint_path = os.path.join(train_path, basename, 'models')
        estimator.load_model(checkpoint_path)
        total_t = estimator.get_global_step()
        if not os.path.exists('./videos'):
            os.makedirs('./videos')
        videoWriter = imageio.get_writer('./videos/{}-{}.mp4'.format(
            basename, total_t), fps=30)
        # The env wrapper forwards rendered frames to the video writer.
        state = env.reset(videowriter=videoWriter)
        lives = env.unwrapped.ale.lives()
        print(lives)
        r = 0
        tot = 0
        while True:
            # Greedy policy: epsilon = 0.0.
            action = estimator.get_action(np.array([state]), 0.0)
            state, reward, done, info = env.step(action)
            r += reward
            tot += 1
            if done:
                lives = env.unwrapped.ale.lives()
                print(lives)
                # `done` fires on each life loss; only 'was_real_done'
                # marks the end of the whole game.
                if info['was_real_done']:
                    print(tot, r)
                    break
                else:
                    state = env.reset()
        videoWriter.close()

    def evaluate(basename, num_eval=10):
        """Play ``num_eval`` full games greedily and print mean/max score."""
        checkpoint_path = os.path.join(train_path, basename, 'models')
        estimator.load_model(checkpoint_path)
        res = []
        for i in tqdm(range(num_eval)):
            # NOTE(review): `//` by INT32_MAX yields a nearly constant seed
            # across episodes — `%` was likely intended; confirm upstream.
            env.seed(int(time.time() * 1000) // 2147483647)
            state = env.reset()
            r = 0
            while True:
                action = estimator.get_action(np.array([state]), 0.0)
                state, reward, done, info = env.step(action)
                r += reward
                if done:
                    # Accumulate reward across lives; record only when the
                    # whole game is over.
                    if info['was_real_done']:
                        res.append(r)
                        break
                    else:
                        state = env.reset()
        print('mean: {}, max: {}'.format(sum(res) / num_eval, max(res)))

    if write_video:
        for basename in basename_list:
            print("Writing {}'s video ...".format(basename))
            visualize(basename)
    else:
        for basename in basename_list:
            print("Evaluating {} ...".format(basename))
            evaluate(basename)