def play():
    if play_params.ENV == 'Pendulum-v0':
        play_env = PendulumWrapper()
    elif play_params.ENV == 'LunarLanderContinuous-v2':
        play_env = LunarLanderContinuousWrapper()
    elif play_params.ENV == 'BipedalWalker-v2':
        play_env = BipedalWalkerWrapper()
    elif play_params.ENV == 'BipedalWalkerHardcore-v2':
        play_env = BipedalWalkerWrapper(hardcore=True)
    else:
        raise Exception('Chosen environment does not have an environment wrapper defined. '
                        'Please choose an environment with an environment wrapper defined, '
                        'or create a wrapper for this environment in utils.env_wrapper.py')

    actor_net = Actor(play_params.STATE_DIMS, play_params.ACTION_DIMS, play_params.ACTION_BOUND_LOW,
                      play_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE,
                      train_params.FINAL_LAYER_INIT, name='actor_play')
    critic_net = Critic(play_params.STATE_DIMS, play_params.ACTION_DIMS, train_params.DENSE1_SIZE,
                        train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS,
                        train_params.V_MIN, train_params.V_MAX, name='critic_play')

    # Load trained weights
    actor_net.load_weights(play_params.ACTOR_MODEL_DIR)
    critic_net.load_weights(play_params.CRITIC_MODEL_DIR)

    if not os.path.exists(play_params.RECORD_DIR):
        os.makedirs(play_params.RECORD_DIR)

    for ep in tqdm(range(1, play_params.NUM_EPS_PLAY + 1), desc='playing'):
        state = play_env.reset()
        state = play_env.normalise_state(state)
        step = 0
        ep_done = False

        while not ep_done:
            frame = play_env.render()
            if play_params.RECORD_DIR is not None:
                filepath = play_params.RECORD_DIR + '/Ep%03d_Step%04d.jpg' % (ep, step)
                cv2.imwrite(filepath, frame)

            action = actor_net(np.expand_dims(state.astype(np.float32), 0))[0]
            state, _, terminal = play_env.step(action)
            state = play_env.normalise_state(state)

            step += 1

            # Episode can finish either by reaching terminal state or max episode steps
            if terminal or step == play_params.MAX_EP_LENGTH:
                ep_done = True

    # Convert saved frames to gif
    if play_params.RECORD_DIR is not None:
        images = []
        for file in tqdm(sorted(os.listdir(play_params.RECORD_DIR)), desc='converting to gif'):
            # Load image
            filename = play_params.RECORD_DIR + '/' + file
            im = cv2.imread(filename)
            images.append(im)
            # Delete static image once loaded
            os.remove(filename)

        # Save as gif
        print("Saving to ", play_params.RECORD_DIR)
        imageio.mimsave(play_params.RECORD_DIR + '/%s.gif' % play_params.ENV, images[:-1], duration=0.01)

    play_env.close()
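# -----------------------------------------------------------------------------
# Illustrative sketch of the play_params fields that play() reads above. The
# field names come from the references in the function; the example values are
# placeholders, not the project's defaults.
#
#   ENV               = 'Pendulum-v0'           # must have a wrapper in utils.env_wrapper.py
#   STATE_DIMS        = (3,)                    # environment observation shape
#   ACTION_DIMS       = (1,)                    # environment action shape
#   ACTION_BOUND_LOW  = -2.0                    # lower action limit for the actor's output scaling
#   ACTION_BOUND_HIGH = 2.0                     # upper action limit
#   ACTOR_MODEL_DIR   = './models/.../actor'    # trained actor weights to load
#   CRITIC_MODEL_DIR  = './models/.../critic'   # trained critic weights to load
#   RECORD_DIR        = './video'               # frames written here, then combined into a gif
#   NUM_EPS_PLAY      = 10                      # number of episodes to play
#   MAX_EP_LENGTH     = 1000                    # step cap per episode
# -----------------------------------------------------------------------------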
class Learner:
    def __init__(self, PER_memory, run_agent_event, stop_agent_event):
        self.PER_memory = PER_memory
        self.run_agent_event = run_agent_event
        self.stop_agent_event = stop_agent_event

        if train_params.ENV == 'Pendulum-v0':
            self.eval_env = PendulumWrapper()
        elif train_params.ENV == 'LunarLanderContinuous-v2':
            self.eval_env = LunarLanderContinuousWrapper()
        elif train_params.ENV == 'BipedalWalker-v2':
            self.eval_env = BipedalWalkerWrapper()
        elif train_params.ENV == 'BipedalWalkerHardcore-v2':
            self.eval_env = BipedalWalkerWrapper(hardcore=True)
        else:
            raise Exception('Chosen environment does not have an environment wrapper defined. '
                            'Please choose an environment with an environment wrapper defined, '
                            'or create a wrapper for this environment in utils.env_wrapper.py')

        self.summary_writer = tf.summary.create_file_writer(train_params.LOG_DIR + '/eval/')

    def build_network(self):
        # Create value (critic) network + target network
        if train_params.USE_BATCH_NORM:
            pass  # for now
            # self.critic_net = Critic_BN(self.state_ph, self.action_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, is_training=True, scope='learner_critic_main')
            # self.critic_target_net = Critic_BN(self.state_ph, self.action_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, is_training=True, scope='learner_critic_target')
        else:
            self.critic_net = Critic(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE,
                                     train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS,
                                     train_params.V_MIN, train_params.V_MAX, name='critic')
            self.critic_target_net = Critic(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE,
                                            train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS,
                                            train_params.V_MIN, train_params.V_MAX, name='critic_target')

        # Create policy (actor) network + target network
        if train_params.USE_BATCH_NORM:
            pass  # for now
            # self.actor_net = Actor_BN(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, is_training=True, scope='learner_actor_main')
            # self.actor_target_net = Actor_BN(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, is_training=True, scope='learner_actor_target')
        else:
            self.actor_net = Actor(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW,
                                   train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE,
                                   train_params.FINAL_LAYER_INIT, name='actor')
            self.actor_target_net = Actor(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW,
                                          train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE,
                                          train_params.FINAL_LAYER_INIT, name='actor_target')

    def target_network_update(self, tau):
        # Soft-update target network params towards the main network params
        network_params = self.actor_net.trainable_variables + self.critic_net.trainable_variables
        target_network_params = self.actor_target_net.trainable_variables + self.critic_target_net.trainable_variables
        for from_var, to_var in zip(network_params, target_network_params):
            to_var.assign(tf.multiply(from_var, tau) + tf.multiply(to_var, 1. - tau))

    def initialise_vars(self):
        # Load initial models if given, then hard copy params to target networks
        self.start_step = 0
        if train_params.INITIAL_ACTOR_MODEL is not None:
            self.actor_net.load_weights(train_params.INITIAL_ACTOR_MODEL)
            self.critic_net.load_weights(train_params.INITIAL_CRITIC_MODEL)
        # Perform hard copy (tau=1.0) of initial params to target networks
        self.target_network_update(1.0)

    def run(self):
        # Sample batches of experiences from replay memory and train learner networks

        # Initialise beta to start value
        priority_beta = train_params.PRIORITY_BETA_START
        beta_increment = (train_params.PRIORITY_BETA_END - train_params.PRIORITY_BETA_START) / train_params.NUM_STEPS_TRAIN

        avg_return = compute_avg_return(self.eval_env, self.actor_net, train_params.MAX_EP_LENGTH)
        scalar_summary(self.summary_writer, "Average Return", avg_return, step=1)

        # Can only train when we have at least batch_size num of samples in replay memory
        while len(self.PER_memory) <= train_params.BATCH_SIZE:
            sys.stdout.write('\rPopulating replay memory up to batch_size samples...')
            sys.stdout.flush()

        t = trange(self.start_step + 1, train_params.NUM_STEPS_TRAIN + 1, desc='[Train]')
        for train_step in t:
            # Get minibatch
            minibatch = self.PER_memory.sample(train_params.BATCH_SIZE, priority_beta)

            states_batch = minibatch[0].astype(np.float32)
            actions_batch = minibatch[1].astype(np.float32)
            rewards_batch = minibatch[2].astype(np.float32)
            next_states_batch = minibatch[3].astype(np.float32)
            terminals_batch = minibatch[4]
            gammas_batch = minibatch[5].astype(np.float32)
            weights_batch = minibatch[6].astype(np.float32)
            idx_batch = minibatch[7]

            # ==================================================================
            # Critic training step
            # ==================================================================
            # Predict actions for next states by passing next states through policy target network
            future_action = self.actor_target_net(next_states_batch)

            # Predict future Z distribution by passing next states and actions through value target network,
            # also get target network's Z-atom values
            _, target_Z_dist = self.critic_target_net(next_states_batch, future_action)
            target_Z_atoms = self.critic_target_net.z_atoms

            # Create batch of target network's Z-atoms
            target_Z_atoms = np.repeat(np.expand_dims(target_Z_atoms, axis=0), train_params.BATCH_SIZE, axis=0)

            # Value of terminal states is 0 by definition
            target_Z_atoms[terminals_batch, :] = 0.0

            # Apply Bellman update to each atom
            target_Z_atoms = np.expand_dims(rewards_batch, axis=1) + (target_Z_atoms * np.expand_dims(gammas_batch, axis=1))

            # Train critic
            td_error, total_loss = self.critic_net.train(states_batch, actions_batch, target_Z_atoms, target_Z_dist, weights_batch)

            # Use critic TD errors to update sample priorities
            self.PER_memory.update_priorities(idx_batch, (np.abs(td_error.numpy()) + train_params.PRIORITY_EPSILON))

            # ==================================================================
            # Actor training step
            # ==================================================================
            # Get policy network's action outputs for selected states
            actor_actions = self.actor_net(states_batch)

            # Get gradients of the critic's output w.r.t. those actions
            action_grads = self.critic_net.get_action_grads(states_batch, actor_actions)

            # Train actor
            self.actor_net.train(states_batch, action_grads)

            # Update target networks
            self.target_network_update(train_params.TAU)

            # Increment beta value at end of every step
            priority_beta += beta_increment

            # Periodically check capacity of replay mem and remove samples (by FIFO process) above this capacity
            if train_step % train_params.REPLAY_MEM_REMOVE_STEP == 0:
                if len(self.PER_memory) > train_params.REPLAY_MEM_SIZE:
                    # Prevent agent from adding new experiences to replay memory while learner removes samples
                    self.run_agent_event.clear()
                    samples_to_remove = len(self.PER_memory) - train_params.REPLAY_MEM_SIZE
                    self.PER_memory.remove(samples_to_remove)
                    # Allow agent to continue adding experiences to replay memory
                    self.run_agent_event.set()

            if train_step % train_params.PRINTOUT_STEP == 0:
                t.set_description('[Train] loss={0:.4f}, avg_return={1:.2f}'.format(total_loss, avg_return))

            if train_step % train_params.EVALUATE_SAVE_MODEL_STEP == 0:
                self.actor_net.save_weights(train_params.LOG_DIR + '/eval/actor_%d' % train_step)
                self.critic_net.save_weights(train_params.LOG_DIR + '/eval/critic_%d' % train_step)

                avg_return = compute_avg_return(self.eval_env, self.actor_net, train_params.MAX_EP_LENGTH)
                scalar_summary(self.summary_writer, "Average Return", avg_return, step=train_step)

        # Stop the agents
        self.stop_agent_event.set()
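# -----------------------------------------------------------------------------
# Minimal wiring sketch (illustrative only, not part of the original pipeline).
# It shows how a Learner could be constructed from its own interface above.
# 'PrioritizedReplayBuffer' and its import path/constructor are assumptions
# standing in for whichever PER implementation the project provides; in the
# full pipeline the same memory and events are shared with the agent threads
# that fill the buffer while the learner trains.
#
#   import threading
#   from utils.prioritised_experience_replay import PrioritizedReplayBuffer  # assumed path
#
#   PER_memory = PrioritizedReplayBuffer(train_params.REPLAY_MEM_SIZE)       # hypothetical ctor
#   run_agent_event = threading.Event()
#   stop_agent_event = threading.Event()
#   run_agent_event.set()        # allow agents to add experiences
#
#   learner = Learner(PER_memory, run_agent_event, stop_agent_event)
#   learner.build_network()
#   learner.initialise_vars()    # hard-copies params to the target networks
#   learner.run()                # trains, then sets stop_agent_event when done
# -----------------------------------------------------------------------------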