def main():
    """Evaluate a trained DQN agent on CartPole-v0.

    Loads the trained weights from ``model_path`` (raising if the file is
    missing), runs ``episode`` evaluation rollouts, prints the per-episode
    total reward, and finally prints the average reward over all episodes.

    Relies on module-level config: ``model_path``, ``epsilon``,
    ``batch_size``, ``episode`` — presumably defined at file top (not
    visible in this chunk).
    """
    import torch  # local import: only needed for the CUDA-availability check

    # Create CartPole environment and network.
    # ``.unwrapped`` removes the TimeLimit wrapper so episodes are not
    # truncated at the default step cap.
    env = gym.make('CartPole-v0').unwrapped
    if not os.path.exists(model_path):
        raise Exception("You should train the DQN first!")
    net = DQN(n_state=env.observation_space.shape[0],
              n_action=env.action_space.n,
              epsilon=epsilon,
              batch_size=batch_size,
              model_path=model_path)
    net.load()
    # FIX: the original called net.cuda() unconditionally, which raises on
    # CPU-only machines. Only move the network when a GPU is available.
    if torch.cuda.is_available():
        net.cuda()

    reward_list = []
    for i in range(episode):
        s = env.reset()
        total_reward = 0
        while True:
            # env.render()

            # Select action and obtain the reward
            a = net.chooseAction(s)
            s_, r, finish, _ = env.step(a)
            total_reward += r
            if finish:
                print("Episode: %d \t Total reward: %d \t Eps: %f" % (i, total_reward, net.epsilon))
                reward_list.append(total_reward)
                break
            s = s_
    env.close()
    print("Testing average reward: ", np.mean(reward_list))
epsilon=epsilon, epsilon_decay=epsilon_decay, update_iter=update_iter, batch_size=batch_size, gamma=gamma, model_path=model_path) net.cuda() net.load() reward_list = [] for i in range(episode): s = env.reset() total_reward = 0 while True: # env.render() # Select action and obtain the reward a = net.chooseAction(s) s_, r, finish, info = env.step(a) # Record the total reward total_reward += r # Revised the reward if finish: # 如果遊戲已結束,則將reward設為0以讓網路收斂 r = 0 else: # ---------------------------------------------------- # 拆解reward,更精準的給予環境需要的訊息 # 1. r1得到的是對於距離的資訊, # -abs term代表鼓勵agent不要去移動車子, # 一直維持在中間才能獲得很高的獎賞!