def main(load_path, num_episode):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build a single vectorised Breakout environment and wrap it so that
    # observations come back as tensors.
    n_env = 1
    env_id = 'Breakout-v0'
    envs = [make_env(env_id) for _ in range(n_env)]
    envs = DummyVecEnv(envs)
    envs = VecToTensor(envs)

    # Load the trained policy weights and switch to evaluation mode.
    policy = Policy(84, 84, 4, envs.action_space.n).to(device)
    policy.load_state_dict(torch.load(load_path, map_location=device))
    policy.eval()

    for i in tqdm(range(num_episode)):
        obs = envs.reset()
        total_rewards = 0
        while True:
            action_logits, values = policy(obs)
            actions = choose_action(action_logits)
            next_obs, rewards, dones, info = envs.step(actions)
            total_rewards += rewards
            obs = next_obs  # act on the latest observation at the next step
            envs.render()
            if dones:
                break
        print('--------------------' + str(total_rewards.item()) + '-------------------')

    envs.close()
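# `make_env`, `Policy`, `VecToTensor` and `choose_action` are project helpers defined
# elsewhere. As a rough sketch only (assumed behaviour, not the project's actual code),
# `choose_action` could sample one action per environment from the categorical
# distribution over the policy's unnormalised logits:

def choose_action(action_logits):
    dist = Categorical(logits=action_logits)  # softmax distribution over the logits
    return dist.sample()                      # one stochastic action per environment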
env_id = 'Breakout-v0'
envs = [make_env(env_id) for _ in range(n_env)]
# envs = DummyVecEnv(envs)
# envs = SubprocVecEnv(envs)
envs = ShmemVecEnv(envs)   # run the n_env environments in shared-memory subprocesses
envs = VecToTensor(envs)   # return observations as tensors

# Log per-episode statistics to a timestamped monitor file under ./tmp/.
date = datetime.now().strftime('%m_%d_%H_%M')
mon_file_name = "./tmp/" + date
envs = VecMonitor(envs, mon_file_name)

# train_policy receives the gradient updates; step_policy is a copy of it,
# kept in eval mode, that the runner uses to collect rollouts.
train_policy = Policy(84, 84, 4, envs.action_space.n).to(device)
step_policy = Policy(84, 84, 4, envs.action_space.n).to(device)
step_policy.load_state_dict(train_policy.state_dict())
step_policy.eval()

runner = Runner(envs, step_policy, n_step, gamma)
optimizer = optim.RMSprop(train_policy.parameters(), lr=lr, alpha=alpha, eps=epsilon)

for i in tqdm(range(num_updates)):
    # Collect an n-step rollout: observations, returns, value estimates and
    # the actions taken by step_policy.
    mb_obs, mb_rewards, mb_values, mb_actions = runner.run()

    # Re-evaluate the rollout with the training policy; the advantage is the
    # rollout return minus the value estimate.
    action_logits, values = train_policy(mb_obs)
    mb_adv = mb_rewards - mb_values
    dist = Categorical(logits=action_logits)
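    # --- Illustrative continuation, not taken from this project ---
    # A generic A2C-style update built from the quantities above. The coefficients
    # vf_coef, ent_coef and max_grad_norm are assumed hyperparameters, and tensor
    # shapes are assumed to line up after the view(-1) below.
    log_probs = dist.log_prob(mb_actions)                          # log pi(a_t | s_t) for the taken actions
    pg_loss = -(log_probs * mb_adv.detach()).mean()                # policy-gradient term, advantage held constant
    value_loss = (mb_rewards - values.view(-1)).pow(2).mean()      # critic regression towards the returns
    entropy = dist.entropy().mean()                                # entropy bonus encourages exploration
    loss = pg_loss + vf_coef * value_loss - ent_coef * entropy

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(train_policy.parameters(), max_grad_norm)
    optimizer.step()
    step_policy.load_state_dict(train_policy.state_dict())         # keep the rollout policy in sync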