import cv2
import gym
import numpy as np
import tensorflow as tf

# Project-local helpers; the module paths below are assumptions -- adjust them
# to the actual package layout.
from experience import Experience
from preprocessing import preprocess_atari_crop


class QAgent(object):
    def __init__(self, config, log_dir):
        self.config = config    # dictionary with game-specific information
        self.log_dir = log_dir  # saving directory
        self.env = gym.make(config['game'])
        self.ale_lives = None

        self.replay_memory = Experience(
            memory_size=config['state_memory'],
            state_shape=config['state_shape'],
            dtype=config['state_dtype']
        )
        self.net = config['q'](
            batch_size=config['batch_size'],
            state_shape=config['state_shape'] + [config['state_time']],
            num_actions=config['actions'],
            summaries=True,
            **config['q_params']
        )
        # Disable the target network if needed
        if not self.config['double_q']:
            self.net.t_batch = self.net.q_batch

        with tf.variable_scope('RL_summary'):
            self.episode = tf.Variable(0., name='episode')
            self.training_reward = tf.Variable(0., name='training_reward')
            self.validation_reward = tf.Variable(0., name='validation_reward')
            tf.summary.scalar(name='training_reward', tensor=self.training_reward)
            tf.summary.scalar(name='validation_reward', tensor=self.validation_reward)

        # Create tensorflow variables
        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        self.update_target_network()
        self.summaries = tf.summary.merge_all()
        if log_dir is not None:
            self.train_writer = tf.summary.FileWriter(self.log_dir, self.session.graph)

        self.epsilon = 1.0
        self.steps = 0

    def update_target_network(self):
        """
        Update the parameters of the target network by assigning the parameters
        of the Q-network to the target network.
        """
        if self.config['double_q']:
            self.session.run(self.net.assign_op)

    def _update_training_reward(self, reward):
        """
        Set the value of the training reward so it is stored and visualised
        on the tensorboard.
        """
        self.session.run(self.training_reward.assign(reward))

    def _update_validation_reward(self, reward):
        """
        Set the value of the validation reward so it is stored and visualised
        on the tensorboard.
        """
        self.session.run(self.validation_reward.assign(reward))

    def get_training_state(self):
        """
        Get the last state from the replay memory.
        """
        return self.replay_memory.get_last_state(self.config['state_time'])

    def sample_action(self, state, epsilon):
        """
        Sample an action for the state according to the epsilon-greedy strategy.
        """
        if np.random.rand() <= epsilon:
            return np.random.randint(0, self.config['actions'])
        else:
            return self.session.run(
                self.net.q,
                feed_dict={self.net.x: state[np.newaxis].astype(np.float32)}
            )[0].argmax()

    def update_state(self, old_state, new_frame):
        """
        Remove the oldest frame from the state history and append the new frame.
        """
        return np.concatenate([old_state[:, 6:, :], new_frame], axis=1)

    def reset_to_zero_state(self):
        """
        Reset the state history to zeros and reset the environment
        to fill the first state.
        """
        return np.concatenate([
            np.zeros((1, self.config['state_shape'][1] - 6, 1)),
            self.config['frame'](self.env.reset())
        ], axis=1)

    def update_epsilon_and_steps(self):
        if self.steps > self.config['step_startrl']:
            self.epsilon = max(
                self.config['eps_minval'],
                self.epsilon * self.config['step_eps_mul'] - self.config['step_eps_min']
            )
        self.steps += 1

    def train_episode(self):
        # Load the last state and add a reset
        state = self.update_state(
            self.get_training_state(),
            self.config['frame'](self.env.reset())
        )
        # Store the starting state in memory
        self.replay_memory.add(
            state=state[:, :, -1],
            action=np.random.randint(self.config['actions']),
            reward=0.,
            done=False
        )
        # Use flags to signal the end of the episode and for pressing the start button
        done = False
        press_fire = True
        total_reward = 0
        while not done:
            if press_fire:
                # Start the episode
                press_fire = False
                new_frame, reward, done, info = self.act(state, -1, True)
                if 'ale.lives' in info:
                    self.ale_lives = info['ale.lives']
            else:
                self.update_epsilon_and_steps()
                new_frame, reward, done, info = self.act(state, self.epsilon, True)
            state = self.update_state(state, new_frame)
            total_reward += reward
            # Perform online Q-learning updates
            if self.steps > self.config['step_startrl']:
                summaries, _ = self.train_batch()
                if self.steps % self.config['tensorboard_interval'] == 0:
                    self.train_writer.add_summary(summaries, global_step=self.steps)
                if self.steps % self.config['double_q_freq'] == 0:
                    print("double q swap")
                    self.update_target_network()
        return total_reward

    def validate_episode(self, epsilon, visualise=False):
        state = self.reset_to_zero_state()
        done = False
        press_fire = True
        total_reward = 0
        while not done:
            if press_fire:
                press_fire = False
                new_frame, reward, done, _ = self.act(state=state, epsilon=-1, store=False)
            else:
                new_frame, reward, done, _ = self.act(state=state, epsilon=epsilon, store=False)
            state = self.update_state(old_state=state, new_frame=new_frame)
            # Do not count the shaping bonus towards the validation score
            if reward == 0.1:
                reward = 0
            total_reward += reward
            if visualise:
                self.env.render()
        return total_reward

    def act(self, state, epsilon, store=False, debug=False, rs=False):
        """
        Perform an action in the environment.

        If it is an Atari game with lives, the episode is ended in the replay
        memory whenever a life is lost. This is important in games like Breakout.
        :param epsilon: the epsilon for the epsilon-greedy strategy.
            If epsilon is -1, the fire action (action 1) is used to start Atari games.
        :param state: the state for which to compute the action
        :param store: if true, the transition is added to the replay memory
        :return: the observed state (processed), the reward, whether the state
            is final, and the info dictionary
        """
        if epsilon == -1:
            action = 1
        else:
            action = self.sample_action(state=state, epsilon=epsilon)
        raw_frame, reward, done, info = self.env.step(action)
        # Clip rewards to -1, 0, 1
        reward = np.sign(reward)
        # Preprocess the output state
        new_frame = self.config['frame'](raw_frame)

        # Optional reward shaping (rs): add a small bonus when the ball's
        # horizontal velocity flips from positive to negative, i.e. the ball
        # bounces back.
        if rs and state[0, 68, 0] > 0 and state[0, 62, 0] > 0 and new_frame[0, 2, 0] > 0:
            v_before = state[0, 68, 0] - state[0, 62, 0]
            v_after = new_frame[0, 2, 0] - state[0, 68, 0]
            if v_after < 0 and v_before > 0:
                reward = 0.1

        if debug:
            # Draw the extracted object coordinates on the cropped frame
            cur_x = new_frame.reshape((6,))
            I = preprocess_atari_crop(raw_frame)
            for idx in range(3):
                if idx == 0:
                    cv2.circle(I, (int(cur_x[2 * idx]), int(cur_x[2 * idx + 1])), 1, (255, 0, 0), -1)
                elif idx == 1:
                    cv2.circle(I, (int(cur_x[2 * idx]), int(cur_x[2 * idx + 1])), 1, (0, 255, 0), -1)
                else:
                    cv2.circle(I, (int(cur_x[2 * idx]), int(cur_x[2 * idx + 1])), 1, (0, 0, 255), -1)
            print(cur_x)
            print('--------------------')
            cv2.imshow('Image', I)
            cv2.waitKey(0)

        # If needed, store the last frame
        if store:
            # End episodes if lives are involved in the Atari game.
            # This is important in e.g. Breakout: if it is not included,
            # dropping the ball is not penalised. By marking the end of the
            # reward propagation, the maximum reward is limited, which makes
            # learning faster.
            store_done = done
            if self.ale_lives is not None and 'ale.lives' in info and info['ale.lives'] < self.ale_lives:
                store_done = True
                self.ale_lives = info['ale.lives']
            self.replay_memory.add(state[:, :, -1], action, reward, store_done)
        return new_frame, reward, done, info

    def train_batch(self):
        """
        Sample a batch of training samples from the replay memory, compute the
        target Q-values and perform one SGD update step.
        :return: summaries, step
            summaries: the tensorflow summaries that can be put into a log.
            step: the global step from tensorflow, i.e. the number of updates.
        """
        # Sample experience
        xp_states, xp_actions, xp_rewards, xp_done, xp_next = self.replay_memory.sample_experience(
            self.config['batch_size'], self.config['state_time']
        )
        # Create a mask on which output to update
        q_mask = np.zeros((self.config['batch_size'], self.config['actions']))
        for idx, a in enumerate(xp_actions):
            q_mask[idx, a] = 1
        # Use the target network to value the next states
        next_actions, next_values = self.session.run(
            [self.net.q_batch, self.net.t_batch],
            feed_dict={self.net.x_batch: xp_next.astype(np.float32)}
        )
        # Combine the reward and the value of the next state into the q-targets
        q_next = np.array([
            next_values[idx, next_actions[idx].argmax()]
            for idx in range(self.config['batch_size'])
        ])
        q_next *= (1. - xp_done)
        q_targets = xp_rewards + self.config['gamma'] * q_next
        # Perform the update
        feed = {
            self.net.x_batch: xp_states.astype(np.float32),
            self.net.q_targets: q_targets[:, np.newaxis] * q_mask,
            self.net.q_mask: q_mask
        }
        _, summaries, step = self.session.run(
            [self.net._train_op, self.summaries, self.net.global_step],
            feed_dict=feed
        )
        return summaries, step
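# ---------------------------------------------------------------------------
# Illustrative driver loop (not part of the original agent): a minimal sketch
# of how the QAgent above might be trained and validated. The config dict is
# assumed to provide the keys the agent reads ('game', 'state_shape',
# 'state_time', 'state_dtype', 'state_memory', 'actions', 'batch_size',
# 'gamma', 'double_q', 'double_q_freq', 'step_startrl', 'eps_minval',
# 'step_eps_mul', 'step_eps_min', 'tensorboard_interval', 'q', 'q_params',
# 'frame'); the episode count, validation interval and validation epsilon
# below are arbitrary example values.
# ---------------------------------------------------------------------------
def run_training(config, log_dir, num_episodes=1000, validate_every=10):
    """Train the agent for a number of episodes and periodically validate it."""
    agent = QAgent(config, log_dir)
    for episode in range(num_episodes):
        # One full training episode with online Q-learning updates
        train_reward = agent.train_episode()
        agent._update_training_reward(train_reward)
        if episode % validate_every == 0:
            # Near-greedy validation episode without storing transitions
            val_reward = agent.validate_episode(epsilon=0.05)
            agent._update_validation_reward(val_reward)
            print('episode %d: train %.1f, validation %.1f'
                  % (episode, train_reward, val_reward))
    return agent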
# ===========================================================================
# Extended variant of QAgent: in addition to the Q-network, it trains three
# class-wise forward models (opponent paddle, ball and own paddle) that
# predict the next object positions from the current state and action.
# ===========================================================================
class QAgent(object):
    def __init__(self, config, log_dir):
        self.config = config    # dictionary with game-specific information
        self.log_dir = log_dir  # saving directory
        self.env = gym.make(config['game'])
        self.ale_lives = None

        # Instantiate replay buffer
        self.replay_memory = Experience(
            memory_size=config['state_memory'],
            state_shape=config['state_shape'],
            dtype=config['state_dtype']
        )

        # Instantiate the Q-networks (running & target)
        self.net = config['q'](
            batch_size=config['batch_size'],
            state_shape=config['state_shape'] + [config['state_time']],
            num_actions=config['actions'],
            summaries=True,
            **config['q_params']
        )

        # Instantiate the class-wise models
        self.model_pro = config['model'](
            batch_size=config['batch_size'],
            state_shape=config['model_state_shape'] + [config['state_time']],
            output_state_shape=3,
            name='Model_pro',
            summaries=True,
            **config['q_params']
        )
        self.model_ball = config['model'](
            batch_size=config['batch_size'],
            state_shape=config['model_state_shape'] + [config['state_time']],
            output_state_shape=3,
            name='Model_ball',
            summaries=True,
            **config['q_params']
        )
        self.model_ant = config['model'](
            batch_size=config['batch_size'],
            state_shape=config['model_state_shape'] + [config['state_time']],
            output_state_shape=3,
            name='Model_ant',
            summaries=True,
            **config['q_params']
        )

        # Disable the target network if needed
        if not self.config['double_q']:
            self.net.t_batch = self.net.q_batch

        with tf.variable_scope('RL_summary'):
            self.episode = tf.Variable(0., name='episode')
            self.training_reward = tf.Variable(0., name='training_reward')
            self.validation_reward = tf.Variable(0., name='validation_reward')
            tf.summary.scalar(name='training_reward', tensor=self.training_reward)
            tf.summary.scalar(name='validation_reward', tensor=self.validation_reward)

        # Create tensorflow variables
        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        self.update_target_network()
        self.summaries = tf.summary.merge_all()
        if log_dir is not None:
            self.train_writer = tf.summary.FileWriter(self.log_dir, self.session.graph)

        self.epsilon = 1.0
        self.steps = 0

    def update_target_network(self):
        """
        Update the parameters of the target network by assigning the parameters
        of the Q-network to the target network.
        """
        if self.config['double_q']:
            self.session.run(self.net.assign_op)

    def _update_training_reward(self, reward):
        """
        Set the value of the training reward so it is stored and visualised
        on the tensorboard.
        """
        self.session.run(self.training_reward.assign(reward))

    def _update_validation_reward(self, reward):
        """
        Set the value of the validation reward so it is stored and visualised
        on the tensorboard.
        """
        self.session.run(self.validation_reward.assign(reward))

    def get_training_state(self):
        """
        Get the last state from the replay memory.
        """
        return self.replay_memory.get_last_state(self.config['state_time'])

    def sample_action(self, state, epsilon):
        """
        Sample an action for the state according to the epsilon-greedy strategy.
        """
        if np.random.rand() <= epsilon:
            return np.random.randint(0, self.config['actions'])
        else:
            return self.session.run(
                self.net.q,
                feed_dict={self.net.x: state[np.newaxis].astype(np.float32)}
            )[0].argmax()

    def update_state(self, old_state, new_frame):
        """
        Remove the oldest frame from the state history and append the new frame.
        """
        return np.concatenate([old_state[:, 9:, :], new_frame], axis=1)

    def reset_to_zero_state(self):
        """
        Reset the state history to zeros and reset the environment
        to fill the first state.
        """
        return np.concatenate([
            np.zeros((1, self.config['state_shape'][1] - 9, 1)),
            self.config['frame'](self.env.reset())
        ], axis=1)

    def reset_to_zero_state_model(self):
        """
        Reset the state history to zeros, reset the environment and press fire
        for a few steps to fill the first state.
        """
        state = self.env.reset()
        for idx in range(15):
            state, _, _, _ = self.env.step(1)
        return np.concatenate([
            np.zeros((1, self.config['state_shape'][1] - 9, 1)),
            self.config['frame'](state)
        ], axis=1)

    def update_epsilon_and_steps(self):
        if self.steps > self.config['step_startrl']:
            self.epsilon = max(
                self.config['eps_minval'],
                self.epsilon * self.config['step_eps_mul'] - self.config['step_eps_min']
            )
        self.steps += 1

    def train_episode(self):
        # Load the last state and add a reset
        state = self.update_state(
            self.get_training_state(),
            self.config['frame'](self.env.reset())
        )
        # Store the starting state in memory
        self.replay_memory.add(
            state=state[:, :, -1],
            action=np.random.randint(self.config['actions']),
            reward=0.,
            done=False
        )
        # Use flags to signal the end of the episode and for pressing the start button
        done = False
        press_fire = True
        total_reward = 0
        agg_loss = np.zeros((3,))
        loop_count = 0
        while not done:
            if press_fire:
                # Start the episode
                press_fire = False
                new_frame, reward, done, info = self.act(state, -1, True)
                if 'ale.lives' in info:
                    self.ale_lives = info['ale.lives']
            else:
                self.update_epsilon_and_steps()
                new_frame, reward, done, info = self.act(state, self.epsilon, True)
            state = self.update_state(state, new_frame)
            total_reward += reward
            # Perform online Q-learning and model updates
            if self.steps > self.config['step_startrl']:
                loop_count += 1
                summaries, _, loss = self.train_batch(loop_count)
                for dummy_idx in range(25):
                    self.train_model_batch(loop_count)
                agg_loss += np.array(loss).reshape((3,))
                if self.steps % self.config['tensorboard_interval'] == 0:
                    self.train_writer.add_summary(summaries, global_step=self.steps)
                if self.steps % self.config['double_q_freq'] == 0:
                    print("double q swap")
                    self.update_target_network()
        if loop_count == 0:
            return total_reward, agg_loss
        else:
            return total_reward, agg_loss / loop_count

    def validate_episode(self, epsilon, visualise=False):
        state = self.reset_to_zero_state()
        done = False
        press_fire = True
        total_reward = 0
        while not done:
            if press_fire:
                press_fire = False
                new_frame, reward, done, _ = self.act(state=state, epsilon=-1, store=False)
            else:
                new_frame, reward, done, _ = self.act(state=state, epsilon=epsilon,
                                                      store=False, debug=False)
            state = self.update_state(old_state=state, new_frame=new_frame)
            # Do not count the shaping bonus towards the validation score
            if reward == 0.1:
                reward = 0
            total_reward += reward
            if visualise:
                self.env.render()
        return total_reward

    def act(self, state, epsilon, store=False, debug=False, rs=False):
        """
        Perform an action in the environment.

        If it is an Atari game with lives, the episode is ended in the replay
        memory whenever a life is lost. This is important in games like Breakout.
        :param epsilon: the epsilon for the epsilon-greedy strategy.
            If epsilon is -1, the fire action (action 1) is used to start Atari games.
        :param state: the state for which to compute the action
        :param store: if true, the transition is added to the replay memory
        :return: the observed state (processed), the reward, whether the state
            is final, and the info dictionary
        """
        if epsilon == -1:
            action = 1
        else:
            action = self.sample_action(state=state, epsilon=epsilon)
        raw_frame, reward, done, info = self.env.step(action)
        # Clip rewards to -1, 0, 1
        reward = np.sign(reward)
        # Preprocess the output state
        new_frame = self.config['frame'](raw_frame)

        if debug:
            # Predict the next object positions with the class-wise models and
            # draw both the prediction and the observed positions on the frame.
            action_input = np.zeros((1, 1, self.config['actions'], 1))
            action_input[0, 0, action, 0] = 1
            m_state = state[np.newaxis].astype(np.float32)
            m_state = m_state[:, :, -108:, :]
            model_input = np.concatenate((m_state, action_input.astype(np.float32)), axis=2)
            # Class-wise state prediction
            model_ant_prediction, model_ball_prediction, model_pro_prediction = self.session.run(
                [self.model_ant.next_state, self.model_ball.next_state, self.model_pro.next_state],
                feed_dict={self.model_ant.x: model_input,
                           self.model_ball.x: model_input,
                           self.model_pro.x: model_input}
            )
            model_ant_prediction = self.update_presence(model_ant_prediction)
            model_pro_prediction = self.update_presence(model_pro_prediction)
            model_ball_prediction = self.update_presence(model_ball_prediction)
            # Concatenate into a joint model state
            model_prediction = np.concatenate(
                (model_ant_prediction, model_ball_prediction, model_pro_prediction), axis=1)
            cur_x = new_frame.reshape((self.config['actions'] + 3,))
            model_predicted_state = model_prediction.reshape((9,))
            I = preprocess_atari_crop(raw_frame)
            for idx in range(3):
                if idx == 0 and model_predicted_state[3 * idx + 2] == 1:
                    # Display the opponent paddle in blue (opencv uses BGR format)
                    cv2.rectangle(
                        I,
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) - 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) - 8),
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) + 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) + 8),
                        (255, 0, 0), 3)
                elif idx == 1 and model_predicted_state[3 * idx + 2] == 1:
                    # Display the ball in green
                    cv2.rectangle(
                        I,
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) - 1,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) - 2),
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) + 1,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) + 2),
                        (0, 255, 0), 3)
                elif idx == 2 and model_predicted_state[3 * idx + 2] == 1:
                    # Display your own paddle in red
                    cv2.rectangle(
                        I,
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) - 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) - 8),
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) + 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) + 8),
                        (0, 0, 255), 3)
                # Mark the observed position in white
                cv2.circle(
                    I,
                    (np.round((cur_x[3 * idx] + 1) * 79.5).astype(int),
                     np.round((cur_x[3 * idx + 1] + 1) * 79.5).astype(int)),
                    1, (255, 255, 255), -1)
            print(cur_x)
            print(model_predicted_state)
            print('--------------------')
            I = cv2.resize(I, None, fx=2, fy=2, interpolation=cv2.INTER_LINEAR)
            cv2.imshow('Image', I)
            cv2.waitKey(0)

        # If needed, store the last frame
        if store:
            # End episodes if lives are involved in the Atari game.
            # This is important in e.g. Breakout: if it is not included,
            # dropping the ball is not penalised. By marking the end of the
            # reward propagation, the maximum reward is limited, which makes
            # learning faster.
            store_done = done
            if self.ale_lives is not None and 'ale.lives' in info and info['ale.lives'] < self.ale_lives:
                store_done = True
                self.ale_lives = info['ale.lives']
            self.replay_memory.add(state[:, :, -1], action, reward, store_done)
        return new_frame, reward, done, info

    def validate_model(self, epsilon):
        """
        Roll out the learned class-wise models (instead of the environment)
        and visualise the predicted object positions.
        """
        state = self.reset_to_zero_state_model()
        done = False
        press_fire = True
        count = 0
        while not done:
            if count > 400:
                done = True
            elif press_fire:
                press_fire = False
                new_frame = self.act_with_model(state=state, epsilon=-1, debug=True)
            else:
                new_frame = self.act_with_model(state=state, epsilon=epsilon, debug=True)
            if new_frame[:, 2, :] == 0:
                # Restart from a fresh environment state when the first
                # presence flag drops to zero
                state = self.reset_to_zero_state_model()
            else:
                state = self.update_state(old_state=state, new_frame=new_frame)
            count += 1
        return

    def act_with_model(self, state, epsilon, debug):
        """
        Take one step using the class-wise models instead of the environment:
        choose an action, feed state and action to the three models and return
        the predicted next object positions as the new frame.
        """
        if epsilon == -1:
            action = 1
        else:
            action = self.sample_action(state=state, epsilon=epsilon)
        action_input = np.zeros((1, 1, self.config['actions'], 1))
        action_input[0, 0, action, 0] = 1
        state = state[np.newaxis].astype(np.float32)
        state = state[:, :, -108:, :]
        model_input = np.concatenate((state, action_input.astype(np.float32)), axis=2)
        # Class-wise state prediction
        model_ant_prediction, model_ball_prediction, model_pro_prediction = self.session.run(
            [self.model_ant.next_state, self.model_ball.next_state, self.model_pro.next_state],
            feed_dict={self.model_ant.x: model_input,
                       self.model_ball.x: model_input,
                       self.model_pro.x: model_input}
        )
        model_ant_prediction = self.update_presence(model_ant_prediction)
        model_pro_prediction = self.update_presence(model_pro_prediction)
        model_ball_prediction = self.update_presence(model_ball_prediction)
        # Concatenate into a joint model state
        model_prediction = np.concatenate(
            (model_ant_prediction, model_ball_prediction, model_pro_prediction), axis=1)

        if debug:
            model_predicted_state = model_prediction.reshape((9,))
            I = np.zeros((160, 160, 3))
            for idx in range(3):
                if idx == 0 and model_predicted_state[3 * idx + 2] == 1:
                    # Display the opponent paddle in blue (opencv uses BGR format)
                    cv2.rectangle(
                        I,
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) - 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) - 8),
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) + 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) + 8),
                        (255, 0, 0), 3)
                elif idx == 1 and model_predicted_state[3 * idx + 2] == 1:
                    # Display the ball in green
                    cv2.rectangle(
                        I,
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) - 1,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) - 2),
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) + 1,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) + 2),
                        (0, 255, 0), 3)
                elif idx == 2 and model_predicted_state[3 * idx + 2] == 1:
                    # Display your own paddle in red
                    cv2.rectangle(
                        I,
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) - 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) - 8),
                        (np.round((model_predicted_state[3 * idx] + 1) * 79.5).astype(int) + 2,
                         np.round((model_predicted_state[3 * idx + 1] + 1) * 79.5).astype(int) + 8),
                        (0, 0, 255), 3)
            I = cv2.resize(I, None, fx=2, fy=2, interpolation=cv2.INTER_LINEAR)
            print(model_predicted_state)
            print('--------------------')
            cv2.imshow('Image', I)
            cv2.waitKey(0)

        model_prediction = model_prediction.reshape((1, 9, 1))
        return model_prediction

    def update_presence(self, prediction):
        """
        Binarise the presence flag (third output) of a model prediction.
        """
        if prediction[0, 2] > 0:
            prediction[0, 2] = 1
        else:
            prediction[0, 2] = 0
        return prediction

    def train_batch(self, loop_count):
        """
        Sample a batch of training samples from the replay memory, compute the
        target Q-values and perform one SGD update step on the Q-network and
        on each of the class-wise models.
        :return: summaries, step, loss
            summaries: the tensorflow summaries that can be put into a log.
            step: the global step from tensorflow, i.e. the number of updates.
            loss: the losses of the three class-wise models.
        """
        # Sample experience
        xp_states, xp_actions, xp_rewards, xp_done, xp_next = self.replay_memory.sample_experience(
            self.config['batch_size'], self.config['state_time']
        )
        # Create a mask on which output to update and encode the actions as one-hot model inputs
        q_mask = np.zeros((self.config['batch_size'], self.config['actions']))
        action_input = np.zeros((self.config['batch_size'], 1, self.config['actions'], 1))
        for idx, a in enumerate(xp_actions):
            q_mask[idx, a] = 1
            action_input[idx, 0, a, 0] = 1
        # Use the target network to value the next states
        next_actions, next_values = self.session.run(
            [self.net.q_batch, self.net.t_batch],
            feed_dict={self.net.x_batch: xp_next.astype(np.float32)}
        )
        # Combine the reward and the value of the next state into the q-targets
        q_next = np.array([
            next_values[idx, next_actions[idx].argmax()]
            for idx in range(self.config['batch_size'])
        ])
        q_next *= (1. - xp_done)
        q_targets = xp_rewards + self.config['gamma'] * q_next

        # Prepare the update of the class-wise models.
        # The input batch (last 108 columns plus the one-hot action) is the same for all models.
        xp_states = xp_states.astype(np.float32)
        m_batch = np.concatenate((xp_states[:, :, -108:, :], action_input.astype(np.float32)), axis=2)
        # Joint next state batch
        m_targets = xp_next.astype(np.float32)
        # Antagonist (opponent paddle) model target
        m_targets_ant = m_targets[:, :, -9:-6, :].reshape((self.config['batch_size'], 3))
        # Ball model target
        m_targets_ball = m_targets[:, :, -6:-3, :].reshape((self.config['batch_size'], 3))
        # Protagonist (own paddle) model target
        m_targets_pro = m_targets[:, :, -3:, :].reshape((self.config['batch_size'], 3))

        feed = {
            self.net.x_batch: xp_states.astype(np.float32),
            self.net.q_targets: q_targets[:, np.newaxis] * q_mask,
            self.net.q_mask: q_mask,
            self.model_ant.x_batch: m_batch,
            self.model_ant.m_targets: m_targets_ant,
            self.model_ball.x_batch: m_batch,
            self.model_ball.m_targets: m_targets_ball,
            self.model_pro.x_batch: m_batch,
            self.model_pro.m_targets: m_targets_pro
        }
        # Train all networks
        (_, _, _, _, next_states, loss_ant, loss_ball, loss_pro,
         summaries, step, _, _, _) = self.session.run(
            [self.net._train_op,
             self.model_ant._train_op, self.model_ball._train_op, self.model_pro._train_op,
             self.model_ball.next_state_batch,
             self.model_ant.loss, self.model_ball.loss, self.model_pro.loss,
             self.summaries, self.net.global_step,
             self.model_ant.global_step, self.model_ball.global_step, self.model_pro.global_step],
            feed_dict=feed)

        if loop_count == 500:
            # Occasionally print targets and predictions for a quick sanity check
            print(m_targets[0:5, :, -9:, :].reshape((5, 9)))
            print(next_states[0:5, :])
        loss = [loss_ant, loss_ball, loss_pro]
        return summaries, step, loss

    def train_model_batch(self, loop_count):
        """
        Sample a batch from the replay memory and perform one SGD update step
        on the class-wise models only (the Q-network is not updated).
        :return: the losses of the three class-wise models.
        """
        # Sample experience
        xp_states, xp_actions, xp_rewards, xp_done, xp_next = self.replay_memory.sample_experience(
            self.config['batch_size'], self.config['state_time']
        )
        # Encode the actions as one-hot model inputs
        action_input = np.zeros((self.config['batch_size'], 1, self.config['actions'], 1))
        for idx, a in enumerate(xp_actions):
            action_input[idx, 0, a, 0] = 1
        # The input batch (last 108 columns plus the one-hot action) is the same for all models
        xp_states = xp_states.astype(np.float32)
        m_batch = np.concatenate((xp_states[:, :, -108:, :], action_input.astype(np.float32)), axis=2)
        # Joint next state batch
        m_targets = xp_next.astype(np.float32)
        # Antagonist (opponent paddle) model target
        m_targets_ant = m_targets[:, :, -9:-6, :].reshape((self.config['batch_size'], 3))
        # Ball model target
        m_targets_ball = m_targets[:, :, -6:-3, :].reshape((self.config['batch_size'], 3))
        # Protagonist (own paddle) model target
        m_targets_pro = m_targets[:, :, -3:, :].reshape((self.config['batch_size'], 3))

        feed = {
            self.model_ant.x_batch: m_batch,
            self.model_ant.m_targets: m_targets_ant,
            self.model_ball.x_batch: m_batch,
            self.model_ball.m_targets: m_targets_ball,
            self.model_pro.x_batch: m_batch,
            self.model_pro.m_targets: m_targets_pro
        }
        # Train the three models
        _, _, _, next_states, loss_ant, loss_ball, loss_pro = self.session.run(
            [self.model_ant._train_op, self.model_ball._train_op, self.model_pro._train_op,
             self.model_ball.next_state_batch,
             self.model_ant.loss, self.model_ball.loss, self.model_pro.loss],
            feed_dict=feed)
        loss = [loss_ant, loss_ball, loss_pro]
        return loss
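# ---------------------------------------------------------------------------
# Minimal sketch of the replay-buffer interface the agents above rely on
# (add, get_last_state, sample_experience). This is an illustrative stand-in,
# not the project's actual `Experience` class: it ignores ring wrap-around and
# episode boundaries when stacking frames, and assumes the buffer already
# holds more than `state_time + 1` frames before sampling, which a production
# buffer would have to handle explicitly.
# ---------------------------------------------------------------------------
class SimpleExperience(object):
    def __init__(self, memory_size, state_shape, dtype=np.float32):
        self.memory_size = memory_size
        self.states = np.zeros([memory_size] + list(state_shape), dtype=dtype)
        self.actions = np.zeros(memory_size, dtype=np.int64)
        self.rewards = np.zeros(memory_size, dtype=np.float32)
        self.done = np.zeros(memory_size, dtype=np.float32)
        self.cursor = 0  # next write position
        self.count = 0   # number of valid entries

    def add(self, state, action, reward, done):
        # Store a single processed frame together with the transition data
        self.states[self.cursor] = state
        self.actions[self.cursor] = action
        self.rewards[self.cursor] = reward
        self.done[self.cursor] = float(done)
        self.cursor = (self.cursor + 1) % self.memory_size
        self.count = min(self.count + 1, self.memory_size)

    def _stack(self, end_idx, state_time):
        # Stack the `state_time` frames ending at `end_idx` along the last axis
        idx = [(end_idx - offset) % self.memory_size
               for offset in reversed(range(state_time))]
        return np.stack([self.states[i] for i in idx], axis=-1)

    def get_last_state(self, state_time):
        # Most recently written frame(s), stacked as (1, width, state_time)
        return self._stack((self.cursor - 1) % self.memory_size, state_time)

    def sample_experience(self, batch_size, state_time):
        # Sample indices such that both the stacked state and its successor exist
        picks = np.random.randint(state_time, self.count - 1, size=batch_size)
        states = np.array([self._stack(i, state_time) for i in picks])
        next_states = np.array([self._stack(i + 1, state_time) for i in picks])
        return (states, self.actions[picks], self.rewards[picks],
                self.done[picks], next_states)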