# Evaluate a trained Q-network loaded from an .nnp file: play 30 greedy
# episodes and append the mean reward to <log_path>/mean_reward.txt.
def main():
    args = get_args()
    nn.set_default_context(
        get_extension_context(args.extension, device_id=args.device_id))

    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=True)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)

    obs_sampler = ObsSampler(args.num_frames)
    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)  # for one file
    explorer = GreedyExplorer(env.action_space.n, use_nnp=True,
                              nnp_file=args.nnp, name='qnet')
    validator = Validator(env, val_replay_memory, explorer, obs_sampler,
                          num_episodes=30, clip_episode_step=True,
                          render=not args.no_render)
    mean_reward = validator.step()
    with open(os.path.join(args.log_path, 'mean_reward.txt'), 'a') as f:
        print("{} {}".format(args.gym_env, str(mean_reward)), file=f)
# Play an Atari game forever with a pretrained Q-network, acting greedily.
# If no .nnp file is given, the pretrained model is downloaded automatically.
def main():
    args = get_args()
    nn.set_default_context(
        get_extension_context(args.extension, device_id=args.device_id))

    if args.nnp is None:
        local_nnp_dir = os.path.join("asset", args.gym_env)
        local_nnp_file = os.path.join(local_nnp_dir, "qnet.nnp")
        if not find_local_nnp(args.gym_env):
            logger.info("Downloading nnp data since you didn't specify...")
            # Join URL segments with '/'; os.path.join would use '\' on Windows.
            nnp_uri = "/".join(
                ["https://nnabla.org/pretrained-models/nnp_models/examples/dqn",
                 args.gym_env, "qnet.nnp"])
            if not os.path.exists(local_nnp_dir):
                # makedirs so the parent "asset" directory is created as well.
                os.makedirs(local_nnp_dir)
            download(nnp_uri, output_file=local_nnp_file, open_file=False)
            logger.info("Download done!")
        args.nnp = local_nnp_file

    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=False)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)

    obs_sampler = ObsSampler(args.num_frames)
    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)
    # just play greedily
    explorer = GreedyExplorer(env.action_space.n, use_nnp=True,
                              nnp_file=args.nnp, name='qnet')
    validator = Validator(env, val_replay_memory, explorer, obs_sampler,
                          num_episodes=1, render=not args.no_render)
    while True:
        validator.step()
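# NOTE: `find_local_nnp` is defined elsewhere in this repository. The helper
# below is only a hypothetical sketch of the contract the play script above
# relies on: return a truthy path when a previously downloaded qnet.nnp
# already exists for the given env, and a falsy value otherwise.
import os


def _find_local_nnp_sketch(gym_env):
    # Pretrained files are expected under asset/<gym_env>/qnet.nnp.
    path = os.path.join("asset", gym_env, "qnet.nnp")
    return path if os.path.isfile(path) else None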
# Train DQN on an Atari environment with the PyTorch backend; validation
# episodes are run between training epochs and logged via a TensorBoard
# SummaryWriter.
def main():
    args = get_args()
    device = (torch.device('cuda', index=args.device_id)
              if torch.cuda.is_available() else torch.device('cpu'))
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)

    if args.log_path:
        output_path = OutputPath(args.log_path)
    else:
        output_path = OutputPath()
    # monitor = Monitor(output_path.path)
    tbw = SummaryWriter(output_path.path)

    # Create an atari env.
    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=False)
    env_val = make_atari_deepmind(args.gym_env, valid=True)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)

    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)
    # 40000 = 10000 * 4 frames
    replay_memory = ReplayMemory(env.observation_space.shape,
                                 env.action_space.shape,
                                 max_memory=40000)
    learner = QLearner(env.action_space.n, device,
                       sync_freq=1000, save_freq=250000, gamma=0.99,
                       learning_rate=1e-4, save_path=output_path)
    explorer = LinearDecayEGreedyExplorer(env.action_space.n, device,
                                          network=learner.get_network(),
                                          eps_start=1.0, eps_end=0.01,
                                          eps_steps=1e6)
    sampler = Sampler(args.num_frames)
    obs_sampler = ObsSampler(args.num_frames)
    validator = Validator(env_val, val_replay_memory, explorer, obs_sampler,
                          num_episodes=args.num_val_episodes,
                          num_eval_steps=args.num_eval_steps,
                          render=args.render_val, tbw=tbw)
    trainer_with_validator = Trainer(env, replay_memory, learner, sampler,
                                     explorer, obs_sampler,
                                     inter_eval_steps=args.inter_eval_steps,
                                     num_episodes=args.num_episodes,
                                     train_start=10000, batch_size=32,
                                     render=args.render_train,
                                     validator=validator, tbw=tbw)
    for e in range(args.num_epochs):
        trainer_with_validator.step()
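# Sketch of the linear-decay epsilon-greedy rule implemented by
# `LinearDecayEGreedyExplorer` (hypothetical helper, not the class used
# above): epsilon decays linearly from eps_start to eps_end over eps_steps
# environment steps, then stays at eps_end.
import random


def _linear_decay_egreedy_sketch(q_values, step, num_actions,
                                 eps_start=1.0, eps_end=0.01, eps_steps=1e6):
    # Current exploration rate for this step.
    eps = max(eps_end, eps_start + (eps_end - eps_start) * step / eps_steps)
    if random.random() < eps:
        return random.randrange(num_actions)  # explore: random action
    # Exploit: greedy action w.r.t. the Q-values predicted by the network.
    return int(max(range(num_actions), key=lambda a: q_values[a]))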
# Train DQN on an Atari environment with the NNabla backend; validation
# episodes are run between training epochs and logged via an NNabla Monitor
# and a TensorBoard SummaryWriter.
def main():
    args = get_args()
    nn.set_default_context(
        get_extension_context(args.extension, device_id=args.device_id))

    if args.log_path:
        output_path = OutputPath(args.log_path)
    else:
        output_path = OutputPath()
    monitor = Monitor(output_path.path)
    tbw = SummaryWriter(output_path.path)

    # Create an atari env.
    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=False)
    env_val = make_atari_deepmind(args.gym_env, valid=True)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)

    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)
    # 40000 = 10000 * 4 frames
    replay_memory = ReplayMemory(env.observation_space.shape,
                                 env.action_space.shape,
                                 max_memory=40000)
    learner = QLearner(q_cnn, env.action_space.n,
                       sync_freq=1000, save_freq=250000, gamma=0.99,
                       learning_rate=1e-4, name_q='q', save_path=output_path)
    explorer = LinearDecayEGreedyExplorer(env.action_space.n,
                                          eps_start=1.0, eps_end=0.01,
                                          eps_steps=1e6,
                                          q_builder=q_cnn, name='q')
    sampler = Sampler(args.num_frames)
    obs_sampler = ObsSampler(args.num_frames)
    validator = Validator(env_val, val_replay_memory, explorer, obs_sampler,
                          num_episodes=args.num_val_episodes,
                          num_eval_steps=args.num_eval_steps,
                          render=args.render_val, monitor=monitor, tbw=tbw)
    trainer_with_validator = Trainer(env, replay_memory, learner, sampler,
                                     explorer, obs_sampler,
                                     inter_eval_steps=args.inter_eval_steps,
                                     num_episodes=args.num_episodes,
                                     train_start=10000, batch_size=32,
                                     render=args.render_train,
                                     validator=validator, monitor=monitor,
                                     tbw=tbw)
    for e in range(args.num_epochs):
        trainer_with_validator.step()
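# NOTE: `get_args` is provided by a shared module in this repository. The
# sketch below only illustrates the argparse surface implied by the attribute
# accesses in the training `main()` above; flag names and defaults are
# illustrative guesses, not the real values.
import argparse


def _get_args_sketch():
    p = argparse.ArgumentParser(description='DQN Atari example (sketch)')
    p.add_argument('--gym-env', default='BreakoutNoFrameskip-v4')
    p.add_argument('--extension', default='cudnn')
    p.add_argument('--device-id', type=int, default=0)
    p.add_argument('--log-path', default=None)
    p.add_argument('--num-frames', type=int, default=4)
    p.add_argument('--num-epochs', type=int, default=30)
    p.add_argument('--num-episodes', type=int, default=100)
    p.add_argument('--num-val-episodes', type=int, default=10)
    p.add_argument('--num-eval-steps', type=int, default=125000)
    p.add_argument('--inter-eval-steps', type=int, default=250000)
    p.add_argument('--render-train', action='store_true')
    p.add_argument('--render-val', action='store_true')
    return p.parse_args()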