class DQN():
    def __init__(self, env_type, state_dims, num_actions):
        if env_type == EnvTypes.ATARI:
            state_size = [state_dims[0], state_dims[1]*FRAME_STACK, state_dims[2]]
        elif env_type == EnvTypes.STANDARD:
            state_size = state_dims
        self.replay_memory = ReplayMemory(REPLAY_MEMORY_CAPACITY, state_size)
        self.exploration = 1.0
        self.train_iter = 0
        self.env_type = env_type

        if env_type == EnvTypes.ATARI:
            buffer_size = FRAME_STACK*FRAME_SKIP
            self.observation_buffer = [np.zeros((state_dims[0], state_dims[1], state_dims[2]))
                                       for _ in range(buffer_size)]
        else:
            self.observation_buffer = [np.zeros((state_dims[0]))]

        self.config = tf.ConfigProto()
        self.config.gpu_options.per_process_gpu_memory_fraction = GPU_MEMORY_FRACTION
        self.sess = tf.Session(config=self.config)

        # build q network
        self.dqn_vars = dict()
        with tf.variable_scope(DQN_SCOPE):
            if env_type == EnvTypes.ATARI:
                self.x, self.initial_layers = self.add_atari_layers(state_dims, self.dqn_vars)
            elif env_type == EnvTypes.STANDARD:
                self.x, self.initial_layers = self.add_standard_layers(state_dims, self.dqn_vars)
            # add final hidden layers
            self.hid = fc(self.initial_layers, 128, HIDDEN, var_dict=self.dqn_vars)
            self.q = fc(self.hid, num_actions, OUTPUT, var_dict=self.dqn_vars, activation=False)
        tf.histogram_summary('q_values', self.q)

        # build target network
        self.target_vars = dict()
        with tf.variable_scope(TARGET_SCOPE):
            if env_type == EnvTypes.ATARI:
                self.t_x, self.t_initial_layers = self.add_atari_layers(state_dims, self.target_vars)
            elif env_type == EnvTypes.STANDARD:
                self.t_x, self.t_initial_layers = self.add_standard_layers(state_dims, self.target_vars)
            self.t_hid = fc(self.t_initial_layers, 128, HIDDEN, var_dict=self.target_vars)
            self.t_q = fc(self.t_hid, num_actions, OUTPUT, var_dict=self.target_vars, activation=False)
        tf.histogram_summary('target_q_values', self.t_q)

        # add weight transfer operations from the primary dqn network to the target
        # network (soft update: target <- TAU*dqn + (1 - TAU)*target)
        self.assign_ops = []
        with tf.variable_scope(TRANSFER_SCOPE):
            for variable in self.dqn_vars.keys():
                target_variable = TARGET_SCOPE + variable[len(DQN_SCOPE):]
                decay = tf.mul(1 - TAU, self.target_vars[target_variable])
                update = tf.mul(TAU, self.dqn_vars[variable])
                new_target_weight = tf.add(decay, update)
                target_assign = self.target_vars[target_variable].assign(new_target_weight)
                self.assign_ops.append(target_assign)

        # build dqn evaluation
        with tf.variable_scope(EVALUATION_SCOPE):
            # one-hot action selection
            self.action = tf.placeholder(tf.int32, shape=[None])
            self.action_one_hot = tf.one_hot(self.action, num_actions)
            # reward
            self.reward = tf.placeholder(tf.float32, shape=[None, 1])
            # terminal state
            self.nonterminal = tf.placeholder(tf.float32, shape=[None, 1])
            # Bellman target: r + GAMMA * max_a' Q_target(s', a'), with the bootstrap
            # term zeroed out at terminal states
            self.target = tf.add(self.reward,
                                 tf.mul(GAMMA, tf.mul(self.nonterminal,
                                                      tf.reduce_max(self.t_q, 1, True))))
            # Q(s, a) for the action actually taken
            self.predict = tf.reduce_sum(tf.mul(self.action_one_hot, self.q), 1, True)
            self.error = tf.reduce_mean(mse(self.predict, self.target))
        tf.scalar_summary('error', self.error)
        # debug op that prints predictions and targets when evaluated (not wired into
        # the training path)
        val_print = tf.Print(self.error, [self.predict, self.target])

        self.optimize = tf.train.RMSPropOptimizer(ALPHA, decay=RMS_DECAY, momentum=MOMENTUM,
                                                  epsilon=EPSILON).minimize(
                                                      self.error, var_list=self.dqn_vars.values())

        # write out the graph and summaries for tensorboard
        self.summaries = tf.merge_all_summaries()
        if os.path.isdir(TENSORBOARD_GRAPH_DIR):
            shutil.rmtree(TENSORBOARD_GRAPH_DIR)
        self.writer = tf.train.SummaryWriter(TENSORBOARD_GRAPH_DIR, self.sess.graph)

        # initialize variables
        self.sess.run(tf.initialize_all_variables())

        # create saver
        self.saver = tf.train.Saver()

    def add_atari_layers(self, dims, var_dict):
        # convolutional feature extractor over the stacked frames, flattened for the
        # fully connected layers that follow
        x = tf.placeholder(tf.float32, shape=[None, dims[0], dims[1]*FRAME_STACK, 1])
        conv1 = conv2d(x, 8, 4, 32, CONV1, var_dict=var_dict)
        conv2 = conv2d(conv1, 4, 2, 64, CONV2, var_dict=var_dict)
        conv3 = conv2d(conv2, 3, 1, 64, CONV3, var_dict=var_dict)
        conv_shape = conv3.get_shape().as_list()
        flatten = [-1, conv_shape[1]*conv_shape[2]*conv_shape[3]]
        return x, tf.reshape(conv3, flatten)

    def add_standard_layers(self, dims, var_dict):
        x = tf.placeholder(tf.float32, shape=[None, dims[0]])
        fc1 = fc(x, 256, FC, var_dict=var_dict)
        return x, fc1

    def process_observation(self, observation):
        if self.env_type == EnvTypes.ATARI:
            # convert to normalized luminance and downscale
            observation = downscale(rgb_to_luminance(observation), 2)
        # push the new observation onto the buffer
        self.observation_buffer.pop(len(self.observation_buffer)-1)
        self.observation_buffer.insert(0, observation)

    def _get_stacked_state(self):
        # concatenate every FRAME_SKIP-th buffered frame into a FRAME_STACK-wide state
        stacked_state = self.observation_buffer[0]
        for i in range(1, FRAME_STACK):
            stacked_state = np.hstack((stacked_state, self.observation_buffer[i*FRAME_SKIP]))
        return stacked_state

    def _predict(self):
        if self.env_type == EnvTypes.ATARI:
            state = self._get_stacked_state()
        else:
            state = self.observation_buffer[0]
        state = np.expand_dims(state, axis=0)
        return np.argmax(self.sess.run(self.q, feed_dict={self.x: state}))

    def training_predict(self, env, observation):
        self.process_observation(observation)
        # select action according to epsilon-greedy policy
        if random.random() < self.exploration:
            action = env.action_space.sample()
        else:
            action = self._predict()
        self.exploration = max(self.exploration - EXPLORATION_DECAY, FINAL_EXPLORATION)
        return action

    def testing_predict(self, observation):
        self.process_observation(observation)
        return self._predict()

    def notify_state_transition(self, action, reward, done):
        if self.env_type == EnvTypes.ATARI:
            state = self._get_stacked_state()
        else:
            state = self.observation_buffer[0]
        self.replay_memory.add_state_transition(state, action, reward, done)
        if done:
            # flush the observation buffer
            for i in range(len(self.observation_buffer)):
                self.observation_buffer[i] = np.zeros(self.observation_buffer[i].shape)

    def batch_train(self, save_dir):
        # sample batch from replay memory
        state, action, reward, terminal, newstate = self.replay_memory.sample(BATCH_SIZE)
        reward = np.expand_dims(reward, axis=1)
        terminal = np.expand_dims(terminal, axis=1)
        nonterminal = 1 - terminal

        # update target network weights
        self.sess.run(self.assign_ops)

        # run neural network training step
        feed_dict = {self.x: state, self.t_x: newstate, self.action: action,
                     self.reward: reward, self.nonterminal: nonterminal}
        if self.train_iter % SUMMARY_PERIOD == 0:
            summary, _ = self.sess.run([self.summaries, self.optimize], feed_dict=feed_dict)
            self.writer.add_summary(summary, self.train_iter)
        else:
            self.sess.run(self.optimize, feed_dict=feed_dict)

        # save the dqn
        if save_dir is not None and self.train_iter % SAVE_CHECKPOINT_PERIOD == 0:
            self.save_algorithm(save_dir)

        self.train_iter += 1

    def save_algorithm(self, save_dir):
        # create directory tree for saving the algorithm
        checkpoint_dir = save_dir + "/save_{}".format(self.train_iter)
        os.mkdir(checkpoint_dir)
        model_file = checkpoint_dir + "/model.ckpt"
        print("Saving algorithm to {}".format(checkpoint_dir))
        t = time.time()
        self.saver.save(self.sess, model_file)
        print("Completed saving in {} seconds".format(time.time() - t))

    def restore_algorithm(self, restore_dir):
        # recover the training iteration from the checkpoint directory name
        self.train_iter = int(restore_dir[restore_dir.rfind("save_") + len("save_"):])
        self.saver.restore(self.sess, restore_dir + "/model.ckpt")
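# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original class): one way to drive
# the DQN agent above with an OpenAI Gym control task. The environment name,
# episode budget, and warm-up threshold below are assumptions made for this
# example only; the project's own runner script may wire things up differently.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import gym

    env = gym.make('CartPole-v0')                     # assumed example environment
    agent = DQN(EnvTypes.STANDARD,
                [env.observation_space.shape[0]],
                env.action_space.n)

    total_steps = 0
    for episode in range(100):                        # assumed episode budget
        observation = env.reset()
        done = False
        while not done:
            action = agent.training_predict(env, observation)
            observation, reward, done, _ = env.step(action)
            agent.notify_state_transition(action, reward, done)
            total_steps += 1
            # wait until the replay memory can fill a batch before training
            if total_steps > BATCH_SIZE:
                agent.batch_train(save_dir=None)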