def run_DDQN(index, env):
    """Train a plain Double-DQN agent on `env` for Config.episode episodes.

    Args:
        index: integer used to namespace this agent's TF variables
            (scope 'DDQN_<index>') so several agents can coexist in one graph.
        env: gym-style environment exposing reset()/step().

    Returns:
        List of per-episode total rewards (scores), one entry per episode.
    """
    # Variable scope only matters while the agent builds its TF graph.
    with tf.variable_scope('DDQN_' + str(index)):
        agent = DQfDDDQN(env, DDQNConfig())
    scores = []
    for e in range(Config.episode):
        done = False
        score = 0  # sum of reward in one episode
        state = env.reset()
        while not done:
            action = agent.egreedy_action(state)  # e-greedy action for train
            next_state, reward, done, _ = env.step(action)
            score += reward
            # Reward shaping: penalize early termination with -100.
            # A score of 499 means the episode ran out the CartPole time
            # limit rather than failing, so it is not penalized.
            reward = reward if not done or score == 499 else -100
            agent.perceive([state, action, reward, next_state, done, 0.0])  # 0. means it is not a demo data
            agent.train_Q_network(update=False)
            state = next_state
            if done:
                scores.append(score)
                # Sync target network weights at episode boundaries.
                agent.sess.run(agent.update_target_net)
                print("episode:", e, " score:", score, " demo_buffer:", len(agent.demo_buffer),
                      " memory length:", len(agent.replay_buffer), " epsilon:", agent.epsilon)
                # if np.mean(scores[-min(10, len(scores)):]) > 490:
                #     break
    return scores
def get_demo_data(env):
    """Generate expert demonstration data by training a DDQN agent on `env`.

    Runs episodes indefinitely, keeping transitions only from perfect-score
    (500) episodes as expert demos. Once the demo buffer reaches
    Config.demo_buffer_size transitions it is truncated to exactly that size
    and pickled to Config.DEMO_DATA_PATH.

    Args:
        env: gym-style environment exposing reset()/step().
    """
    # env = wrappers.Monitor(env, '/tmp/CartPole-v0', force=True)
    # agent.restore_model()
    with tf.variable_scope('get_demo_data'):
        agent = DQfDDDQN(env, DDQNConfig())
    e = 0
    while True:
        done = False
        score = 0  # sum of reward in one episode
        state = env.reset()
        demo = []
        while not done:
            action = agent.egreedy_action(state)  # e-greedy action for train
            next_state, reward, done, _ = env.step(action)
            score += reward
            # -100 on early termination; 499 == hit the time limit (not a failure).
            reward = reward if not done or score == 499 else -100
            agent.perceive([state, action, reward, next_state, done, 0.0])  # 0. means it is not a demo data
            demo.append([state, action, reward, next_state, done, 1.0])  # record the data that could be expert-data
            agent.train_Q_network(update=False)
            state = next_state
        # End-of-episode bookkeeping runs at the OUTER loop level so that
        # `break` below terminates data collection, not just the episode.
        if done:
            if score == 500:  # expert demo data
                demo = set_n_step(demo, Config.trajectory_n)
                agent.demo_buffer.extend(demo)
            agent.sess.run(agent.update_target_net)
            print("episode:", e, " score:", score, " demo_buffer:", len(agent.demo_buffer),
                  " memory length:", len(agent.replay_buffer), " epsilon:", agent.epsilon)
            if len(agent.demo_buffer) >= Config.demo_buffer_size:
                # Truncate to exactly demo_buffer_size transitions.
                agent.demo_buffer = deque(itertools.islice(agent.demo_buffer, 0, Config.demo_buffer_size))
                break
        e += 1
    with open(Config.DEMO_DATA_PATH, 'wb') as f:
        pickle.dump(agent.demo_buffer, f, protocol=pickle.HIGHEST_PROTOCOL)