def train_dqn(self):
    """Train a DQN agent against this environment and checkpoint it.

    Runs up to ``EPISODES`` episodes. Each episode:

    * resets the environment and reads the initial observation,
    * repeatedly asks the agent for an action and steps the environment,
    * stores each transition in the agent's replay memory, and
    * trains the agent after every step.

    When an episode ends (``done``) or its score reaches 500, the target
    network is synced, the episode score is plotted to
    ``./cartpole_dqn.png`` and printed.

    Training stops early once the mean of the last 10 episode scores
    exceeds 490. Weights are written to ``./cartpole_dqn.h5`` every 20
    episodes, and also on early stop and after the final episode, so the
    last trained model is always persisted.
    """
    # Observation vector length and number of discrete actions.
    state_size = 4
    action_size = 2
    agent = DQNAgent(state_size, action_size)

    scores, episodes = [], []
    for episode in range(1, EPISODES + 1):
        done = False
        score = 0
        self.reset()
        # NOTE(review): step(-1) appears to be used only to read the initial
        # observation (no real action) -- confirm against step().
        state, _, _, _ = self.step(-1)
        state = np.reshape(state, [1, state_size])
        rospy.loginfo("Episode %d: starting", episode)

        while not done:
            action = agent.get_action(state)
            next_state, reward, done, _ = self.step(action)
            rospy.loginfo(
                "Episode %d: action: %d pitch: %f", episode, action, next_state[0]
            )
            next_state = np.reshape(next_state, [1, state_size])

            # An action that ends the episode gets a -100 penalty, unless the
            # score was 499 (i.e. this step reaches the 500-point cap).
            penalized = done and score != 499
            if penalized:
                reward = -100

            # Save the sample <s, a, r, s'> to replay memory and train each step.
            agent.append_sample(state, action, reward, next_state, done)
            agent.train_model()
            score += reward
            state = next_state

            if done or score >= 500:
                # Sync the target model with the online model once per episode.
                agent.update_target_model()
                # Undo the shaping penalty for reporting, but only when it was
                # actually applied (hitting the 500 cap applies no penalty).
                if penalized:
                    score += 100
                scores.append(score)
                episodes.append(episode)
                # Clear first: re-plotting the full history onto the same
                # figure every episode would otherwise accumulate lines.
                pylab.clf()
                pylab.plot(episodes, scores, "b")
                pylab.savefig("./cartpole_dqn.png")
                print(
                    "episode:",
                    episode,
                    " score:",
                    score,
                    " memory length:",
                    len(agent.memory),
                    " epsilon:",
                    agent.epsilon,
                )
                break

        # Stop training once the mean score of the last 10 episodes exceeds 490.
        converged = np.mean(scores[-10:]) > 490
        # Checkpoint periodically, and always before leaving the loop so a
        # converged (or final) model is not lost.
        if converged or episode % 20 == 0 or episode == EPISODES:
            agent.model.save_weights("./cartpole_dqn.h5")
        rospy.loginfo("Episode %d: completed", episode)
        if converged:
            break