def fit(self):
    """Run the complete Q-network training loop.

    A new game is started and its first screen is written into
    ``self.history`` ``self.params.history_length`` times. A TensorFlow
    session is then opened (limited to ``self.params.gpu_memory`` of the GPU),
    all variables are initialised, and -- if ``self.params.model_file`` is
    set -- the weights and global iteration are restored from
    ``self.checkpoint_dir``. The summary writer is built, the target network
    is synchronised with the train network, and ``self._sel_move`` /
    ``self._train`` are alternated for ``self.params.num_iterations``
    iterations.
    """
    params = self.params

    # Atari environments are started through new_random_game(); every other
    # environment uses new_game().
    # NOTE(review): train() unpacks new_random_game() as (reward, screen, ...);
    # confirm that self.game really returns (screen, reward, done).
    start_game = (self.game.new_random_game if params.env == 'atari'
                  else self.game.new_game)
    screen, reward, is_done = start_game()

    # Seed the short-term history with copies of the first screen.
    for _slot in range(params.history_length):
        self.history.add(screen)

    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=params.gpu_memory))

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())

        # Checkpoints hold only the trainable weights plus the global iteration.
        saver = tf.train.Saver(
            tf.trainable_variables() + [self.dqn_train.global_iteration],
            max_to_keep=200)

        if params.model_file is not None:
            # Resume from a previously saved model inside the checkpoint dir.
            saver.restore(sess, os.path.join(self.checkpoint_dir,
                                             params.model_file))
            self.train_iteration, learning_rate = sess.run(
                [self.dqn_train.global_iteration,
                 self.dqn_train.learning_rate])
            print(
                "Restarted training from model file. iteration = %06i, Learning Rate = %.5f"
                % (self.train_iteration, learning_rate))

        self.dqn_train.build_summary_writer(sess)

        # The target network starts out with exactly the train network's weights.
        update_target_network(sess, "qnetwork-train", "qnetwork-target")

        for iteration in range(params.num_iterations):
            self._sel_move(sess, iteration)
            self._train(sess, iteration, saver)

        print("Finished training Q-network.")
def _timestamp():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM``."""
    return datetime.now().strftime("%Y-%m-%d %H:%M")


def _run_evaluation(sess, params, atari, history, dqn_train, feed_dict,
                    train_step):
    """Play ``params.eval_steps`` epsilon-greedy moves and report the results.

    The statistics are assigned to the ``dqn_train`` evaluation variables,
    ``dqn_train.eval_summary_op`` is written through
    ``dqn_train.train_summary_writer`` and a short report is printed.
    Side effects: the game is restarted with ``atari.new_game()`` and
    ``history`` is overwritten with evaluation screens.

    Args:
        sess: active TensorFlow session.
        params: run configuration (reads ``eval_steps``, ``eval_epsilon`` and
            ``history_length``).
        atari: game environment being played.
        history: short-term screen history fed to the Q-network.
        dqn_train: the trainable Q-network.
        feed_dict: feed dictionary of the latest training batch; it is reused
            to evaluate ``dqn_train.eval_summary_op``.
        train_step: global training step used as the summary step.
    """
    eval_total_reward = 0
    eval_num_episodes = 0
    eval_num_rewards = 0
    eval_episode_max_reward = 0
    eval_episode_reward = 0
    eval_actions = np.zeros(atari.num_actions)

    def reset_game():
        # Start a game without random start moves and refill the history.
        _reward, first_screen, _terminal = atari.new_game()
        for _ in range(params.history_length):
            history.add(first_screen)

    reset_game()
    for _ in range(params.eval_steps):
        if random.random() < params.eval_epsilon:
            action_id = random.randrange(atari.num_actions)
        else:
            # Greedy action from the Q-values of the current screen history.
            qvalues = sess.run(dqn_train.qvalues,
                               feed_dict={dqn_train.pl_screens: history.get()})
            action_id = np.argmax(qvalues[0])

        # Keep track of how many of each action is performed.
        eval_actions[action_id] += 1

        reward, screen, terminal = atari.act(action_id)
        history.add(screen)
        eval_episode_reward += reward
        if reward > 0:
            eval_num_rewards += 1

        if terminal:
            eval_total_reward += eval_episode_reward
            eval_episode_max_reward = max(eval_episode_reward,
                                          eval_episode_max_reward)
            eval_episode_reward = 0
            eval_num_episodes += 1
            reset_game()

    # Avoid 0/0 (NaN) when no evaluation moves were played.
    total_actions = np.sum(eval_actions)
    action_distribution = (eval_actions / total_actions
                           if total_actions > 0 else eval_actions)

    # Send statistics about the environment to TensorBoard.
    sess.run([
        dqn_train.eval_rewards.assign(eval_total_reward),
        dqn_train.eval_num_rewards.assign(eval_num_rewards),
        dqn_train.eval_max_reward.assign(eval_episode_max_reward),
        dqn_train.eval_num_episodes.assign(eval_num_episodes),
        dqn_train.eval_actions.assign(action_distribution),
    ])
    summaries = sess.run(dqn_train.eval_summary_op, feed_dict=feed_dict)
    dqn_train.train_summary_writer.add_summary(summaries, train_step)

    print("[%s] Evaluation Summary" % _timestamp())
    print(" Total Reward: %i" % eval_total_reward)
    print(" Max Reward per Episode: %i" % eval_episode_max_reward)
    print(" Num Episodes: %i" % eval_num_episodes)
    print(" Num Rewards: %i" % eval_num_rewards)


def _format_action_summary(atari, count_actions):
    """Return ``"<id, name, count, fraction> "`` entries for every action.

    Args:
        atari: environment providing ``action_to_string``.
        count_actions: per-action counters (indexed by action id).
    """
    sum_actions = float(np.sum(count_actions))
    entries = []
    for action_id, action_count in enumerate(count_actions):
        action_perc = action_count / sum_actions if sum_actions != 0 else 0
        entries.append("<%i, %s, %i, %.2f> " % (
            action_id, atari.action_to_string(action_id), action_count,
            action_perc))
    return "".join(entries)


def train(params):
    """Train (or, with ``params.is_train`` false, only play) a DQN agent.

    The agent uses two Q-networks:

    * "qnetwork-train" is optimised on batches sampled from a replay memory.
    * "qnetwork-target" provides the bootstrap targets and is re-synchronised
      every ``params.network_update_rate`` training steps.

    The optimisation, evaluation, summaries and checkpointing all happen
    inside the training branch of the step loop. Actions are epsilon-greedy.
    During training, epsilon is annealed linearly from 1.0 down to 0.1 over
    ``params.epsilon_step`` frames; when only playing it is fixed at 0.05.

    Args:
        params: run configuration object. The attributes read here include
            ``random_start_wait``, ``show_game``, ``replay_capacity``,
            ``batch_size``, ``output_dir``, ``interval_summary``,
            ``replay_mem_dump``, ``is_train``, ``history_length``,
            ``model_file``, ``num_steps``, ``train_start``, ``train_freq``,
            ``epsilon_step``, ``network_update_rate``, ``eval_frequency``,
            ``eval_steps``, ``eval_epsilon`` and ``interval_checkpoint``.
            ``gpu_memory`` is optional and defaults to 0.4.
    """
    # Load the game and prepare the environment.
    atari = GymEnvironment(params.random_start_wait, params.show_game)

    # Two Q-value networks: one for training and one for target prediction.
    dqn_train = DeepQNetwork(
        params=params,
        num_actions=atari.num_actions,
        network_name="qnetwork-train",
        trainable=True
    )
    dqn_target = DeepQNetwork(
        params=params,
        num_actions=atari.num_actions,
        network_name="qnetwork-target",
        trainable=False
    )

    # Replay memory for storing experience to sample batches from.
    replay_mem = ReplayMemory(params.replay_capacity, params.batch_size)

    # Small structure for storing the last screens.
    history = ScreenHistory(params)

    # TensorFlow assumes the checkpoint directory already exists.
    checkpoint_dir = os.path.abspath(os.path.join(params.output_dir, "checkpoints"))
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    train_step = 0
    count_actions = np.zeros(atari.num_actions)  # count per action (greedy only)
    count_act_random = 0  # count of random actions
    count_act_greedy = 0  # count of greedy actions

    # Histories of Q-values and loss for running averages.
    qvalues_hist = collections.deque([0] * params.interval_summary,
                                     maxlen=params.interval_summary)
    loss_hist = collections.deque([10] * params.interval_summary,
                                  maxlen=params.interval_summary)

    # Time measurements (last 10 training steps).
    dt_batch_gen = collections.deque([0] * 10, maxlen=10)
    dt_optimization = collections.deque([0] * 10, maxlen=10)
    dt_train_total = collections.deque([0] * 10, maxlen=10)

    # Optionally load a pre-initialized replay memory from disk.
    if params.replay_mem_dump is not None and params.is_train:
        print("Loading pre-initialized replay memory from HDF5 file.")
        replay_mem.load(params.replay_mem_dump)

    # Initialize a new game and store its first screen in the history.
    _reward, screen, is_terminal = atari.new_random_game()
    for _ in range(params.history_length):
        history.add(screen)

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=getattr(params, "gpu_memory", 0.4))

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

        sess.run(tf.global_variables_initializer())

        # Only save trainable variables and the global step to disk.
        tf_vars_to_save = tf.trainable_variables() + [dqn_train.global_step]
        saver = tf.train.Saver(tf_vars_to_save, max_to_keep=40)

        if params.model_file is not None:
            # Load a pre-trained model from disk.
            saver.restore(sess, params.model_file)
            train_step, learning_rate = sess.run(
                [dqn_train.global_step, dqn_train.learning_rate])
            print("Restarted training from model file. Step = %06i, Learning Rate = %.5f"
                  % (train_step, learning_rate))

        dqn_train.build_summary_writer(sess)

        # The target network starts out with the train network's weights.
        update_target_network(sess, "qnetwork-train", "qnetwork-target")

        for step in range(params.num_steps):

            replay_mem_size = replay_mem.num_examples()
            if params.is_train and replay_mem_size < params.train_start and step % 1000 == 0:
                print("Initializing replay memory %i/%i" % (step, params.train_start))

            # Epsilon-greedy exploration with a minimum epsilon of 0.1 while training.
            if params.is_train:
                epsilon = max(0.1, 1.0 - float(train_step * params.train_freq)
                              / float(params.epsilon_step))
            else:
                epsilon = 0.05

            ####################### SELECT A MOVE ##########################

            # Act randomly with probability epsilon, and always while the
            # replay memory is still being filled during training.
            do_random_action = (random.random() < epsilon)
            if do_random_action or (replay_mem_size < params.train_start and params.is_train):
                action_id = random.randrange(atari.num_actions)
                count_act_random += 1
            else:
                action_feed = {dqn_train.pl_screens: history.get()}
                qvalues = sess.run(dqn_train.qvalues, feed_dict=action_feed)
                qvalue_max = np.max(qvalues[0])
                action_id = np.argmax(qvalues[0])
                count_act_greedy += 1
                count_actions[action_id] += 1
                qvalues_hist.append(qvalue_max)

            ####################### PLAY THE MOVE ##########################

            cumulative_reward, screen, is_terminal = atari.act(action_id)

            # Reward clipping to [-1, 1].
            cumulative_reward = min(+1.0, max(-1.0, cumulative_reward))

            history.add(screen)
            if params.is_train:
                replay_mem.add(action_id, cumulative_reward, screen, is_terminal)

            # Game over: start a new game.
            if is_terminal:
                reward, screen, is_terminal = atari.new_random_game()
                # The replay memory is only maintained while training,
                # consistent with the add() above.
                if params.is_train:
                    replay_mem.add(0, reward, screen, is_terminal)
                history.add(screen)

            ###################### TRAINING MODEL ##########################

            if params.is_train and step > params.train_start and step % params.train_freq == 0:

                t1 = time.time()

                # TODO: set actions with terminal == 1 to reward = -1 ??
                screens_in, actions, rewards, screens_out, terminals = replay_mem.sample_batch()
                dt_batch_gen.append(time.time() - t1)

                t2 = time.time()

                # Target Q-values come from the fixed target network,
                # evaluated on the successor screens.
                qvalues_target = sess.run(
                    dqn_target.qvalues,
                    feed_dict={dqn_target.pl_screens: screens_out})

                feed_dict = {
                    dqn_train.pl_screens: screens_in,
                    dqn_train.pl_actions: actions,
                    dqn_train.pl_rewards: rewards,
                    dqn_train.pl_terminals: terminals,
                    dqn_train.pl_qtargets: np.max(qvalues_target, axis=1),
                }

                _, loss, train_step = sess.run(
                    [dqn_train.train_op, dqn_train.loss, dqn_train.global_step],
                    feed_dict=feed_dict)

                t3 = time.time()
                dt_optimization.append(t3 - t2)
                dt_train_total.append(t3 - t1)

                loss_hist.append(loss)
                if np.isnan(loss):
                    print("[%s] Training failed with loss = NaN." % _timestamp())

                # Periodically copy the train weights into the target network.
                if train_step % params.network_update_rate == 0:
                    print("[%s] Updating target network." % _timestamp())
                    update_target_network(sess, "qnetwork-train", "qnetwork-target")

                ####################### MODEL EVALUATION #######################

                if train_step % params.eval_frequency == 0:
                    # NOTE(review): evaluation restarts the game and replaces the
                    # history; training then continues from the evaluation game.
                    _run_evaluation(sess, params, atari, history, dqn_train,
                                    feed_dict, train_step)

                ###################### PRINTING / SAVING #######################

                if train_step % params.interval_summary == 0:

                    avg_dt_batch_gen = sum(dt_batch_gen) / float(len(dt_batch_gen))
                    avg_dt_optimization = sum(dt_optimization) / float(len(dt_optimization))
                    avg_dt_total = sum(dt_train_total) / float(len(dt_train_total))
                    samples_per_second = ((1.0 / avg_dt_total) * params.batch_size
                                          if avg_dt_total > 0 else 0.0)

                    # Send statistics about the environment to TensorBoard.
                    sess.run([
                        dqn_train.avg_reward_per_game.assign(atari.avg_reward_per_episode()),
                        dqn_train.max_reward_per_game.assign(atari.max_reward_per_episode),
                        dqn_train.avg_moves_per_game.assign(atari.avg_steps_per_episode()),
                        dqn_train.total_reward_replay.assign(replay_mem.total_reward()),
                        dqn_train.num_games_played.assign(atari.episode_number),
                        dqn_train.actions_random.assign(count_act_random),
                        dqn_train.actions_greedy.assign(count_act_greedy),
                        dqn_train.runtime_batch.assign(avg_dt_batch_gen),
                        dqn_train.runtime_train.assign(avg_dt_optimization),
                        dqn_train.runtime_total.assign(avg_dt_total),
                        dqn_train.samples_per_second.assign(samples_per_second),
                    ])

                    summaries = sess.run(dqn_train.train_summary_op, feed_dict=feed_dict)
                    dqn_train.train_summary_writer.add_summary(summaries, train_step)

                    avg_qvalue = sum(qvalues_hist) / float(len(qvalues_hist))
                    avg_loss = sum(loss_hist) / float(len(loss_hist))

                    format_str = ("[%s] Step %06i, ReplayMemory = %i, Epsilon = %.4f, "
                                  "Episodes = %i, Avg.Reward = %.2f, Max.Reward = %.2f, "
                                  "Avg.QValue = %.4f, Avg.Loss = %.6f")
                    print(format_str % (_timestamp(), train_step,
                                        replay_mem.num_examples(), epsilon,
                                        atari.episode_number,
                                        atari.avg_reward_per_episode(),
                                        atari.max_reward_per_episode,
                                        avg_qvalue, avg_loss))

                if train_step % params.interval_checkpoint == 0:
                    saver.save(sess, checkpoint_prefix, global_step=train_step)
                    print("[%s] Saving TensorFlow model checkpoint to disk." % _timestamp())

                    # TODO: dumping the replay memory (replay_mem.save) is
                    # disabled until it is fixed.

                    format_str = "[%s] Q-Network Actions Summary: NumRandom: %i, NumGreedy: %i, %s"
                    print(format_str % (_timestamp(), count_act_random,
                                        count_act_greedy,
                                        _format_action_summary(atari, count_actions)))

        print("Finished training Q-network.")
def _train(self, sess, iteration, saver):
    """Perform one scheduled optimisation step of the train Q-network.

    Nothing happens unless ``self.params.is_train`` is truthy, ``iteration``
    is greater than ``self.params.train_start`` and ``iteration`` is a
    multiple of ``self.params.train_freq``.

    When a step runs:

    * A batch is sampled from ``self.replay_mem``.
    * The train network chooses the greedy action for each successor screen.
    * The target network supplies Q-values for those same successor screens.
    * Both are fed to ``self.dqn_train.train_op``. The original comments call
      this a Double-DQN update; the actual combination happens inside the
      graph.

    Afterwards:

    * The loss is appended to ``self.loss_hist``.
    * ``self.train_iteration`` is refreshed from
      ``self.dqn_train.global_iteration``.
    * The target network is re-synchronised whenever that iteration is a
      multiple of ``self.params.network_update_rate``.
    * ``self._evaluate`` and ``self._print_save`` are called with this step's
      feed dictionary.

    Args:
        sess: active TensorFlow session.
        iteration: index of the current iteration of the outer loop.
        saver: checkpoint saver forwarded to ``self._print_save``.
    """
    params = self.params
    due = (params.is_train
           and iteration > params.train_start
           and iteration % params.train_freq == 0)
    if not due:
        return

    train_net = self.dqn_train
    target_net = self.dqn_target

    screens, actions, rewards, next_screens, dones = self.replay_mem.sample_batch()

    # Step 1: the train network picks the greedy action for every next state.
    next_q_online = sess.run(train_net.qvalues,
                             feed_dict={train_net.pl_screens: next_screens})
    greedy_next_actions = np.argmax(next_q_online, 1)

    # Step 2: the target network scores all actions in those next states;
    # the graph later selects the entries belonging to greedy_next_actions.
    next_q_target = sess.run(target_net.qvalues,
                             feed_dict={target_net.pl_screens: next_screens})

    feed_dict = {
        train_net.pl_screens: screens,
        train_net.pl_actions: actions,
        train_net.pl_rewards: rewards,
        train_net.pl_dones: dones,
        train_net.pl_qtargets: next_q_target,
        train_net.pl_actions_target: greedy_next_actions,
    }
    fetches = [train_net.train_op, train_net.loss, train_net.global_iteration]
    _, loss, self.train_iteration = sess.run(fetches, feed_dict=feed_dict)

    # Feeds the running average of the loss.
    self.loss_hist.append(loss)

    if np.isnan(loss):
        print("[%s] Training failed with loss = NaN."
              % datetime.now().strftime("%Y-%m-%d %H:%M"))

    # Periodically copy the train weights into the target network.
    if self.train_iteration % params.network_update_rate == 0:
        print("[%s] Updating target network."
              % datetime.now().strftime("%Y-%m-%d %H:%M"))
        update_target_network(sess, "qnetwork-train", "qnetwork-target")

    self._evaluate(sess, feed_dict)
    self._print_save(sess, feed_dict, saver)