def _step(self) -> Dict[str, tf.Tensor]:
  """Do a step of SGD and update the priorities."""

  # Pull out the data needed for updates/priorities.
  inputs = next(self._iterator)
  transitions: types.Transition = inputs.data
  keys, probs = inputs.info[:2]

  with tf.GradientTape() as tape:
    # Evaluate our networks.
    q_tm1 = self._network(transitions.observation)
    q_t_value = self._target_network(transitions.next_observation)
    q_t_selector = self._network(transitions.next_observation)

    # The rewards and discounts have to have the same type as network values.
    r_t = tf.cast(transitions.reward, q_tm1.dtype)
    if self._max_abs_reward:
      r_t = tf.clip_by_value(r_t, -self._max_abs_reward, self._max_abs_reward)
    d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast(
        self._discount, q_tm1.dtype)

    # Compute the loss.
    _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t, d_t,
                                     q_t_value, q_t_selector)
    loss = losses.huber(extra.td_error, self._huber_loss_parameter)

    # Get the importance weights.
    importance_weights = 1. / probs  # [B]
    importance_weights **= self._importance_sampling_exponent
    importance_weights /= tf.reduce_max(importance_weights)

    # Reweight.
    loss *= tf.cast(importance_weights, loss.dtype)  # [B]
    loss = tf.reduce_mean(loss, axis=[0])  # []

  # Do a step of SGD.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  # Get the priorities that we'll use to update.
  priorities = tf.abs(extra.td_error)

  # Periodically update the target network.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(self._network.variables,
                         self._target_network.variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  # Report loss & statistics for logging.
  fetches = {
      'loss': loss,
      'keys': keys,
      'priorities': priorities,
  }

  return fetches

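# --- Reference sketch (not part of the learner above). A minimal, self-contained
# example of the trfl.double_qlearning call used in _step, on dummy eager tensors.
# The batch size, action count, and values below are assumptions for illustration;
# only the argument order and the (loss, extra) return mirror the code above.
import tensorflow as tf
import trfl

batch_size, num_actions = 3, 4
q_tm1 = tf.random.normal([batch_size, num_actions])         # Q(s_tm1, .) from online net
a_tm1 = tf.constant([0, 2, 1], dtype=tf.int32)              # actions taken
r_t = tf.constant([1.0, 0.0, -1.0])                         # rewards
pcont_t = tf.constant([0.99, 0.99, 0.0])                    # discount * (1 - terminal)
q_t_value = tf.random.normal([batch_size, num_actions])     # Q(s_t, .) from target net
q_t_selector = tf.random.normal([batch_size, num_actions])  # Q(s_t, .) from online net

loss, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, pcont_t,
                                    q_t_value, q_t_selector)
print(loss.shape)            # [batch_size] per-element loss
print(extra.td_error.shape)  # [batch_size] TD errors, used above for Huber loss / priorities
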
def _build_model_for_training(self):
    inputs = self.create_inputs("main", **self.model_kwargs)
    model = self.create_model(inputs, **self.model_kwargs)
    model_vars = model.trainable_weights
    q = model.output

    with tf.name_scope('training'):
        # Input placeholders
        actions = tf.placeholder(tf.int32, (None,), name="action")
        rewards = tf.placeholder(tf.float32, (None,), name="reward")
        inputs_next = self.create_inputs("next", **self.model_kwargs)
        terminates = tf.placeholder(tf.bool, (None,), name="terminate")

        # Target network
        target_model = self.create_model(inputs_next, **self.model_kwargs)
        target_vars = target_model.trainable_weights
        q_next = tf.stop_gradient(target_model.output)
        q_next_online_net = tf.stop_gradient(model(inputs_next))

        # Loss
        pcontinues = (1.0 - tf.to_float(terminates)) * self.gamma
        errors, _info = double_qlearning(q, actions, rewards, pcontinues,
                                         q_next, q_next_online_net)
        td_error = _info.td_error
        loss = K.mean(errors)
        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        optimize_expr = optimizer.minimize(loss, var_list=model_vars)
        with tf.control_dependencies([optimize_expr]):
            optimize_expr = tf.group(*[tf.assign(*a) for a in model.updates])

        # update_target_fn will be called periodically to copy Q network to target Q network
        update_target_expr = tf.group(*[
            var_target.assign(var)
            for var, var_target in zip(model_vars, target_vars)
        ])

        # Create callable functions
        train_fn = K.function(
            inputs + [actions, rewards, terminates] + inputs_next,
            outputs=[td_error],
            updates=[optimize_expr])
        act_fn = K.function(inputs=inputs, outputs=[K.argmax(q, axis=1)])
        q_fn = K.function(inputs=inputs, outputs=[q])
        update_fn = K.function([], [], updates=[update_target_expr])

        self._update_parameters = lambda: update_fn([])
        self._train = train_fn
        self._act = lambda x: act_fn([x])[0]
        self._q = lambda x: q_fn([x])[0]

    return model

def _step(self) -> Dict[str, tf.Tensor]:
  """Do a step of SGD and update the priorities."""

  # Pull out the data needed for updates/priorities.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
  keys, probs = inputs.info[:2]

  with tf.GradientTape() as tape:
    # Evaluate our networks.
    q_tm1 = self._network(o_tm1)
    q_t_value = self._target_network(o_t)
    q_t_selector = self._network(o_t)

    # The rewards and discounts have to have the same type as network values.
    r_t = tf.cast(r_t, q_tm1.dtype)
    r_t = tf.clip_by_value(r_t, -1., 1.)
    d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount, q_tm1.dtype)

    # Compute the loss.
    _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                     q_t_selector)
    loss = losses.huber(extra.td_error, self._huber_loss_parameter)

    # Get the importance weights.
    importance_weights = 1. / probs  # [B]
    importance_weights **= self._importance_sampling_exponent
    importance_weights /= tf.reduce_max(importance_weights)

    # Reweight.
    loss *= tf.cast(importance_weights, loss.dtype)  # [B]
    loss = tf.reduce_mean(loss, axis=[0])  # []

  # Do a step of SGD.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  # Update the priorities in the replay buffer.
  if self._replay_client:
    priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
    self._replay_client.update_priorities(
        table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities)

  # Periodically update the target network.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(self._network.variables,
                         self._target_network.variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  # Report loss & statistics for logging.
  fetches = {
      'loss': loss,
  }

  return fetches

def __init__(self,
             name,
             learning_rate=0.01,
             state_size=4,
             action_size=2,
             hidden_size=10,
             batch_size=20):
    with tf.variable_scope(name):
        self._inputs = tf.placeholder(tf.float32, [None, state_size],
                                      name='inputs')
        self._actions = tf.placeholder(tf.int32, [batch_size], name='actions')

        self.fc1 = tf.contrib.layers.fully_connected(self._inputs, hidden_size)
        self.fc2 = tf.contrib.layers.fully_connected(self.fc1, hidden_size)
        self.fc3 = tf.contrib.layers.fully_connected(self.fc2, hidden_size)
        self.fc4 = tf.contrib.layers.fully_connected(self.fc3, hidden_size)
        self.output = tf.contrib.layers.fully_connected(self.fc4,
                                                        action_size,
                                                        activation_fn=None)
        self.name = name

        self._targetQs = tf.placeholder(tf.float32, [batch_size, action_size],
                                        name='target')
        self.reward = tf.placeholder(tf.float32, [batch_size], name='reward')
        self.discount = tf.constant(0.99,
                                    shape=[batch_size],
                                    dtype=tf.float32,
                                    name='discount')

        # trfl.double_qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector)
        q_loss, q_learning = trfl.double_qlearning(self.output, self._actions,
                                                   self.reward, self.discount,
                                                   self._targetQs, self.output)
        self.loss = tf.reduce_mean(q_loss)
        self.opt = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)

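# --- Hedged usage sketch (hypothetical helper, not from the snippet above). Assuming
# the class above is named QNetwork and its variables live under tf.variable_scope(name),
# this is the usual TF1-style way to periodically copy the online network's weights into
# a separately constructed target network. Names QNetwork/'main'/'target' are assumptions.
def update_target_graph(online_scope, target_scope):
    online_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, online_scope)
    target_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, target_scope)
    # Assign each online variable to its positional counterpart in the target scope.
    return [t.assign(o) for o, t in zip(online_vars, target_vars)]

# Usage sketch:
#   mainQN = QNetwork('main', ...)
#   targetQN = QNetwork('target', ...)
#   copy_ops = update_target_graph('main', 'target')
#   sess.run(copy_ops)  # periodically sync target <- online
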
def q_learning(vision_model_dict, agent_model_dict, target_agent_model_dict,
               inputs, batch_size, kp_type, agent_size, mask_threshold,
               patch_sizes, kpt_encoder_type, mp_steps, img_size, lsp_layers,
               window_size, gamma, double_q, n_step_q):
    """
    :param vision_model_dict:
    :param agent_model_dict:
    :param target_agent_model_dict:
    :param inputs: bottom_up_kpt inputs [batch, T, dims]
    :param batch_size: (int)
    :param kp_type: (str) "transporter" or "permakey" type of keypoint used for bottom-up processing
    :param agent_size: (int) size of agent lstm
    :param mask_threshold: (float)
    :param patch_sizes: (int) size of patch size for "permakey" keypoints
    :param kpt_encoder_type: (str) "cnn" for conv-net "gnn" for graph-net
    :param mp_steps: (int) number of message-passing steps in GNNs
    :param img_size: (int) size of input image (H for H x H img)
    :param lsp_layers: (tuple) of layers for "permakey" keypoints
    :param window_size: (int) size of window used for recurrent q-learning
    :param gamma: (float) discount factor
    :param double_q: (bool) True if using double q-learning
    :param n_step_q: (int) 'n' value used for n-step q-learning
    :return: loss and the trfl extra namedtuple (td_error, targets, ...)
    """
    # unpacking elements from sampled trajectories from buffer
    obses_tm1, a_tm1, r_t, dones = (inputs[0][0], inputs[0][1], inputs[0][2],
                                    inputs[0][3])
    obses_tm1 = tf.cast(obses_tm1, dtype=tf.float32) / 255.0  # (batch, T, H, W)

    # reshaping obs tensor (batch, T, H, W, C) -> (batch*T, H, W, C)
    obses_tm1_shape = obses_tm1.shape
    obses_tm1 = tf.reshape(obses_tm1, [
        obses_tm1_shape[0] * obses_tm1_shape[1], obses_tm1_shape[2],
        obses_tm1_shape[3], obses_tm1_shape[4]
    ])

    # 1 single forward pass of kpt-module for T-steps of frames
    vis_forward_start = time.time()
    bottom_up_maps, encoder_features, kpt_centers = vision_forward_pass(
        obses_tm1, vision_model_dict, lsp_layers, kp_type, patch_sizes,
        img_size)

    # reshaping tensors from (b*T, ...) -> (b, T, ...)
    bup_map_shape = bottom_up_maps.shape
    bottom_up_maps = tf.reshape(bottom_up_maps, [
        obses_tm1_shape[0], obses_tm1_shape[1], bup_map_shape[1],
        bup_map_shape[2], bup_map_shape[3]
    ])
    enc_feat_shape = encoder_features.shape
    encoder_features = tf.reshape(encoder_features, [
        obses_tm1_shape[0], obses_tm1_shape[1], enc_feat_shape[1],
        enc_feat_shape[2], enc_feat_shape[3]
    ])
    kpt_c_shape = kpt_centers.shape
    kpt_centers = tf.reshape(kpt_centers, [
        obses_tm1_shape[0], obses_tm1_shape[1], kpt_c_shape[1], kpt_c_shape[2]
    ])

    # splitting outputs into 2 parts: targets = (1:T) and qs = (0:T-1)
    bottom_up_maps_tm1, bottom_up_maps_t = (
        bottom_up_maps[:, n_step_q:-1, :, :, :],
        bottom_up_maps[:, n_step_q + 1:, :, :, :])
    encoder_features_tm1, encoder_features_t = (
        encoder_features[:, n_step_q:-1, :, :, :],
        encoder_features[:, n_step_q + 1:, :, :, :])
    kpt_centers_tm1, kpt_centers_t = (kpt_centers[:, n_step_q:-1, :, :],
                                      kpt_centers[:, n_step_q + 1:, :, :])

    # collecting a_tm1, r_t and dones for n'th step bootstrapping
    a_tm1, r_t = (tf.cast(a_tm1, dtype=tf.int32),
                  tf.cast(r_t, dtype=tf.float32))
    a_tm1, r_t = a_tm1[:, n_step_q:-1, :], r_t[:, 0:-1, :]
    dones = tf.cast(dones, dtype=tf.float32)
    dones = dones[:, n_step_q + 1:, 1]  # dones for q_t's

    # switching batch and time axis to align all inputs i.e. (T, b, ..) -> (b, T, ..)
    a_tm1 = tf.transpose(a_tm1, perm=[1, 0, 2])
    dones = tf.transpose(dones, perm=[1, 0])

    # reshaping tensors again (ugh!) (b, T-1, ...) -> (b*(T-1), ...)
    bup_tm1_shape = bottom_up_maps_tm1.shape
    bottom_up_maps_tm1 = tf.reshape(
        bottom_up_maps_tm1,
        [-1, bup_tm1_shape[2], bup_tm1_shape[3], bup_tm1_shape[4]])
    bottom_up_maps_t = tf.reshape(bottom_up_maps_t, bottom_up_maps_tm1.shape)
    enc_tm1_shape = encoder_features_tm1.shape
    encoder_features_tm1 = tf.reshape(
        encoder_features_tm1,
        [-1, enc_tm1_shape[2], enc_tm1_shape[3], enc_tm1_shape[4]])
    encoder_features_t = tf.reshape(encoder_features_t,
                                    encoder_features_tm1.shape)
    kptc_tm1_shape = kpt_centers_tm1.shape
    kpt_centers_tm1 = tf.reshape(kpt_centers_tm1,
                                 [-1, kptc_tm1_shape[2], kptc_tm1_shape[3]])
    kpt_centers_t = tf.reshape(kpt_centers_t, kpt_centers_tm1.shape)

    # compute keypoint encodings
    kpts_features_tm1 = encode_keypoints(
        bottom_up_maps_tm1, encoder_features_tm1, kpt_centers_tm1,
        mask_threshold, kp_type, kpt_encoder_type, mp_steps, True,
        pos_net=agent_model_dict.get("pos_net"),
        kpt_encoder=agent_model_dict.get("kpt_encoder"),
        node_encoder=agent_model_dict.get("node_enc"))  # passes None if not available
    kpts_features_t = encode_keypoints(
        bottom_up_maps_t, encoder_features_t, kpt_centers_t,
        mask_threshold, kp_type, kpt_encoder_type, mp_steps, True,
        pos_net=target_agent_model_dict.get("pos_net"),
        kpt_encoder=target_agent_model_dict.get("kpt_encoder"),
        node_encoder=target_agent_model_dict.get("node_enc"))  # passes None if not available

    # reshaping back the time axis (b*T, dims) -> (b, T, dims)
    kpts_features_tm1 = tf.expand_dims(kpts_features_tm1, axis=1)
    kpts_tm1_shape = kpts_features_tm1.shape
    kpts_features_tm1 = tf.reshape(
        kpts_features_tm1, [batch_size, window_size, kpts_tm1_shape[-1]])
    kpts_features_t = tf.expand_dims(kpts_features_t, axis=1)
    kpts_t_shape = kpts_features_t.shape
    kpts_features_t = tf.reshape(kpts_features_t,
                                 [batch_size, window_size, kpts_t_shape[-1]])

    # RNN computation
    q_tm1_seq = []
    q_t_seq = []
    q_t_selector_seq = []

    # reset lstm state at start of update as in R-DQN random updates
    c_tm1 = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    h_tm1 = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    h_t_sel = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    c_t_sel = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    h_t = tf.Variable(tf.zeros((batch_size, agent_size)),
                      trainable=False)  # td_targets
    c_t = tf.Variable(tf.zeros((batch_size, agent_size)),
                      trainable=False)  # td_targets

    rnn_unroll_start = time.time()
    # RNN unrolling
    for seq_idx in tf.range(window_size):
        s_tm1 = kpts_features_tm1[:, seq_idx, :]
        s_t = kpts_features_t[:, seq_idx, :]

        # double_q action selection step
        if double_q:
            q_t_selector, h_t_sel, c_t_sel = agent_model_dict["agent_net"](
                s_t, [h_t_sel, c_t_sel], training=True)
            q_t_selector_seq.append(q_t_selector)

        q_tm1, h_tm1, c_tm1 = agent_model_dict["agent_net"](
            s_tm1, [h_tm1, c_tm1], training=True)
        q_tm1_seq.append(q_tm1)
        q_t, h_t, c_t = target_agent_model_dict["agent_net"](
            s_t, [h_t, c_t], training=False)
        q_t_seq.append(q_t)
    # print("RNN for loop unrolling took %s" % (time.time() - rnn_unroll_start))

    q_tm1 = tf.convert_to_tensor(q_tm1_seq, dtype=tf.float32)
    q_t = tf.convert_to_tensor(q_t_seq, dtype=tf.float32)

    # compute cumulative reward for 'n' steps
    if n_step_q > 1:
        l = tf.constant(np.array(list(range(n_step_q))), dtype=tf.float32)
        discounts = tf.math.pow(gamma, l)
        # slice r_t [b, T] into moving windows of [b, t-k, k]
        # cumsum over k steps
        r_t = tf.transpose(r_t, perm=[1, 0, 2])
        r_t_sliced = tf.convert_to_tensor(
            [r_t[t:t + n_step_q, :, :] for t in range(window_size)],
            dtype=tf.float32)
        r_t_sliced = tf.squeeze(tf.transpose(r_t_sliced, perm=[0, 2, 1, 3]))
        r_t_sl_shape = r_t_sliced.shape
        # reshape (batch, T, n) -> (batch*T, n)
        r_t_sliced = tf.reshape(
            r_t_sliced, [r_t_sl_shape[0] * r_t_sl_shape[1], r_t_sl_shape[2]])
        # r_t_sliced [T*batch, n_steps] x discounts [n_steps, 1]
        r_t = tf.linalg.matvec(r_t_sliced, discounts)
        r_t = tf.reshape(r_t, [r_t_sl_shape[0], r_t_sl_shape[1]])

    # reshape again to make tensors compatible with trfl API
    q_tm1_shape = q_tm1.shape
    q_tm1 = tf.reshape(q_tm1,
                       [q_tm1_shape[0] * q_tm1_shape[1], q_tm1_shape[2]])
    q_t = tf.reshape(q_t, [q_tm1_shape[0] * q_tm1_shape[1], q_tm1_shape[2]])
    a_tm1_shape = a_tm1.shape
    a_tm1 = tf.squeeze(
        tf.reshape(a_tm1, [a_tm1_shape[0] * a_tm1_shape[1], a_tm1_shape[2]]))
    r_t_shape = r_t.shape
    r_t = tf.reshape(r_t, [r_t_shape[0] * r_t_shape[1]])
    dones_shape = dones.shape
    dones = tf.reshape(dones, [dones_shape[0] * dones_shape[1]])

    p_cont = 0.0
    if n_step_q == 1:
        # discount factor (at t=1) for bootstrapped value
        p_cont = tf.math.multiply(tf.ones(dones.shape) - dones, gamma)
    elif n_step_q > 1:
        # discount factor (at t=n+1) accordingly for bootstrapped value
        p_cont = tf.math.multiply(
            tf.ones(dones.shape) - dones, tf.math.pow(gamma, n_step_q))

    loss, extra = 0.0, None
    if not double_q:
        loss, extra = trfl.qlearning(q_tm1, a_tm1, r_t, p_cont, q_t)
    elif double_q:
        q_t_selector = tf.convert_to_tensor(q_t_selector_seq, dtype=tf.float32)
        q_t_selector = tf.reshape(
            q_t_selector, [q_tm1_shape[0] * q_tm1_shape[1], q_tm1_shape[2]])
        loss, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, p_cont, q_t,
                                            q_t_selector)

    # average over batch_dim = (batch*time)
    loss = tf.reduce_mean(loss, axis=0)
    # print("Inside q_learning bellman updates took %4.5f" % (time.time() - q_backup_start))
    return loss, extra

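# --- Standalone illustration (toy values, assumed) of the n-step return trick used
# above: sliding windows of rewards [B*T, n] are multiplied by the discount vector
# [gamma^0, ..., gamma^(n-1)] with tf.linalg.matvec to get discounted n-step returns.
import tensorflow as tf

gamma, n_step_q = 0.99, 3
discounts = tf.math.pow(gamma, tf.range(n_step_q, dtype=tf.float32))  # [1., 0.99, 0.9801]
r_windows = tf.constant([[1.0, 0.0, 1.0],
                         [0.0, 1.0, 0.0]])      # two windows of n rewards each
n_step_returns = tf.linalg.matvec(r_windows, discounts)
print(n_step_returns.numpy())                   # [1.9801, 0.99]
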
def main(unused_argv):
    ''' check path '''
    if FLAGS.data_dir == '' or not os.path.exists(FLAGS.data_dir):
        raise ValueError('invalid data directory {}'.format(FLAGS.data_dir))

    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(
            FLAGS.output_dir))
    elif not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    event_log_dir = os.path.join(FLAGS.output_dir, '')
    checkpoint_path = os.path.join(FLAGS.output_dir, 'model.ckpt')

    ''' setup summaries '''
    summ = Summaries()

    ''' setup the game environment '''
    filenames_train = glob.glob(
        os.path.join(FLAGS.data_dir, 'train-{}'.format(FLAGS.sampling_rate),
                     '*.mat'))
    filenames_val = glob.glob(
        os.path.join(FLAGS.data_dir, 'val-{}'.format(FLAGS.sampling_rate),
                     '*.mat'))
    game_env_train = Env(decay=FLAGS.decay)
    game_env_val = Env(decay=FLAGS.decay)
    game_actions = list(game_env_train.actions.keys())

    ''' setup the transition table for experience replay '''
    stateDim = [FLAGS.num_chans, FLAGS.num_points]
    transition_args = {
        'batchSize': FLAGS.batch_size,
        'stateDim': stateDim,
        'numActions': len(game_actions),
        'maxSize': FLAGS.replay_memory,
    }
    transitions = TransitionMemory(transition_args)

    ''' setup agent '''
    s_placeholder = tf.placeholder(tf.float32, [FLAGS.batch_size] + stateDim,
                                   's_placeholder')
    s2_placeholder = tf.placeholder(tf.float32, [FLAGS.batch_size] + stateDim,
                                    's2_placeholder')
    a_placeholder = tf.placeholder(tf.int32, [FLAGS.batch_size],
                                   'a_placeholder')
    r_placeholder = tf.placeholder(tf.float32, [FLAGS.batch_size],
                                   'r_placeholder')
    pcont_t = tf.constant(FLAGS.discount, tf.float32, [FLAGS.batch_size])

    network = Model(FLAGS.batch_size, len(game_actions), FLAGS.num_chans,
                    FLAGS.sampling_rate, FLAGS.num_filters, FLAGS.num_recurs,
                    FLAGS.pooling_stride, name="network")
    target_network = Model(FLAGS.batch_size, len(game_actions),
                           FLAGS.num_chans, FLAGS.sampling_rate,
                           FLAGS.num_filters, FLAGS.num_recurs,
                           FLAGS.pooling_stride, name="target_n")

    q = network(s_placeholder)
    q2 = target_network(s2_placeholder)
    q_selector = network(s2_placeholder)
    loss, q_learning = trfl.double_qlearning(q, a_placeholder, r_placeholder,
                                             pcont_t, q2, q_selector)

    synchronizer = Synchronizer(network, target_network)
    sychronize_ops = synchronizer()

    training_variables = network.variables
    opt = Adam(FLAGS.learning_rate,
               lr_decay=FLAGS.lr_decay,
               lr_decay_steps=FLAGS.lr_decay_steps,
               lr_decay_factor=FLAGS.lr_decay_factor,
               clip=True)
    reduced_loss = tf.reduce_mean(loss)
    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_regularization_loss = tf.reduce_sum(graph_regularizers)
    total_loss = reduced_loss + total_regularization_loss
    update_op = opt(total_loss, var_list=training_variables)
    summ_loss_op = tf.summary.scalar('loss', total_loss)

    state_placeholder = tf.placeholder(tf.float32, [1] + stateDim,
                                       'state_placeholder')
    decayed_ep_placeholder = tf.placeholder(tf.float32, [],
                                            'decayed_ep_placeholder')
    action_tensor_egreedy = eGreedy(state_placeholder, network,
                                    len(game_actions), decayed_ep_placeholder,
                                    FLAGS.debug)
    action_tensor_greedy = greedy(state_placeholder, network)

    ''' setup the training process '''
    episode_reward_placeholder = tf.placeholder(tf.float32, [],
                                                "episode_reward_placeholder")
    average_reward_placeholder = tf.placeholder(tf.float32, [],
                                                "average_reward_placeholder")
    summ.register('train', 'episode_reward_train', episode_reward_placeholder)
    summ.register('train', 'average_reward_train', average_reward_placeholder)
    summ.register('val', 'episode_reward_val', episode_reward_placeholder)
    summ.register('val', 'average_reward_val', average_reward_placeholder)

    total_reward_train = 0
    average_reward_train = 0
    total_reward_val = 0
    average_reward_val = 0

    ''' gathering summary operators '''
    train_summ_op = summ('train')
    val_summ_op = summ('val')

    ''' setup the training process '''
    transitions.empty()
    # print("game_actions -> {}".format(game_actions))

    writer = tf.summary.FileWriter(event_log_dir, tf.get_default_graph())
    saver = tf.train.Saver(training_variables)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    assert (FLAGS.gpus != ''), 'invalid GPU specification'
    config.gpu_options.visible_device_list = FLAGS.gpus

    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        val_step = 0
        for step in range(FLAGS.steps):
            print("Iteration: {}".format(step))
            game_env_train.reset(filenames_train[np.random.randint(
                0, len(filenames_train))])
            last_state = None
            last_state_assigned = False
            episode_reward = 0
            action_index = (len(game_actions) >> 2)

            for estep in range(FLAGS.eval_steps):
                # print("Evaluation step: {}".format(estep))
                # print("{} - measured RT: {}".format(estep, game_env_train.measured_rt))
                # print("{} - predicted RT: {}".format(estep, game_env_train.predicted_rt))
                # print("{} - action -> {}".format(estep, game_actions[action]))
                state, reward, terminal = game_env_train.step(
                    game_actions[action_index])

                # game over?
                if terminal:
                    break
                episode_reward += reward

                # Store transition s, a, r, t
                # if last_state_assigned and reward:
                if last_state_assigned:
                    # print("reward -> {}".format(reward))
                    # print("action -> {}".format(game_actions[last_action]))
                    transitions.add(last_state, last_action, reward,
                                    last_terminal)

                # Select action
                # decayed_ep = FLAGS.testing_ep
                decayed_ep = max(0.1,
                                 (FLAGS.steps - step) / FLAGS.steps * FLAGS.ep)
                if not terminal:
                    action_index = sess.run(
                        action_tensor_egreedy,
                        feed_dict={
                            state_placeholder: np.expand_dims(state, axis=0),
                            decayed_ep_placeholder: decayed_ep
                        })
                else:
                    action_index = 0

                # Do some Q-learning updates
                if estep > FLAGS.learn_start and estep % FLAGS.update_freq == 0:
                    summ_str = None
                    for _ in range(FLAGS.n_replay):
                        if transitions.size > FLAGS.batch_size:
                            s, a, r, s2 = transitions.sample()
                            summ_str, _ = sess.run(
                                [summ_loss_op, update_op],
                                feed_dict={
                                    s_placeholder: s,
                                    a_placeholder: a,
                                    r_placeholder: r,
                                    s2_placeholder: s2
                                })
                    if summ_str:
                        writer.add_summary(summ_str,
                                           step * FLAGS.eval_steps + estep)

                last_state = state
                last_state_assigned = True
                last_action = action_index
                last_terminal = terminal

                if estep > FLAGS.learn_start and estep % FLAGS.target_q == 0:
                    # print("duplicate model parameters")
                    sess.run(sychronize_ops)

            total_reward_train += episode_reward
            average_reward_train = total_reward_train / (step + 1)
            train_summ_str = sess.run(
                train_summ_op,
                feed_dict={
                    episode_reward_placeholder: episode_reward,
                    average_reward_placeholder: average_reward_train
                })
            writer.add_summary(train_summ_str, step)

            if FLAGS.validation and step % FLAGS.validation_interval == 0:
                game_env_val.reset(filenames_val[0])
                episode_reward = 0
                count = 0
                action_index = (len(game_actions) >> 2)
                while True:
                    # print("Evaluation step: {}".format(count))
                    # print("action -> {}".format(game_actions[action_index]))
                    state, reward, terminal = game_env_val.step(
                        game_actions[action_index])

                    # game over?
                    if terminal:
                        break
                    episode_reward += reward

                    if not terminal:
                        action_index = sess.run(
                            action_tensor_greedy,
                            feed_dict={
                                state_placeholder:
                                    np.expand_dims(state, axis=0)
                            })
                        action_index = np.squeeze(action_index)
                        # print('state -> {}'.format(state))
                        # print('action_index -> {}'.format(action_index))
                    else:
                        action_index = 0
                    count += 1

                total_reward_val += episode_reward
                average_reward_val = total_reward_val / (val_step + 1)
                val_step += 1
                val_summ_str = sess.run(
                    val_summ_op,
                    feed_dict={
                        episode_reward_placeholder: episode_reward,
                        average_reward_placeholder: average_reward_val
                    })
                writer.add_summary(val_summ_str, step)

        tf.logging.info('Saving model.')
        saver.save(sess, checkpoint_path)

    tf.logging.info('Training complete')
    writer.close()

def _step(self) -> Dict[str, tf.Tensor]:
  """Do a step of SGD and update the priorities."""

  # Pull out the data needed for updates/priorities.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
  keys, probs = inputs.info[:2]

  with tf.GradientTape() as tape:
    # Evaluate our networks.
    q_tm1 = self._network(o_tm1)
    q_t_value = self._target_network(o_t)
    q_t_selector = self._network(o_t)

    # The rewards and discounts have to have the same type as network values.
    r_t = tf.cast(r_t, q_tm1.dtype)
    r_t = tf.clip_by_value(r_t, -1., 1.)
    d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount, q_tm1.dtype)

    # Compute the loss.
    _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                     q_t_selector)
    loss = losses.huber(extra.td_error, self._huber_loss_parameter)

    if self._alpha:
      policy_probs = self._emp_policy.lookup([str(o) for o in o_tm1])
      # soft-maximum of the q func
      push_down = tf.reduce_logsumexp(q_tm1 * self._tr, axis=1) / self._tr
      # expected q value under behavioural policy
      push_up = tf.reduce_sum(policy_probs * q_tm1, axis=1)
      cql_loss = loss + self._alpha * (push_down - push_up)
    else:
      cql_loss = loss

    cql_loss = tf.reduce_mean(cql_loss, axis=0)

  # Do a step of SGD.
  gradients = tape.gradient(cql_loss, self._network.trainable_variables)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  # Update the priorities in the replay buffer.
  if self._replay_client:
    priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
    self._replay_client.update_priorities(
        table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities)

  # Periodically update the target network.
  if tf.math.mod(self._counter.get_counts()['learner_steps'],
                 self._target_update_period) == 0:
    for src, dest in zip(self._network.variables,
                         self._target_network.variables):
      dest.assign(src)

  # Report loss & statistics for logging.
  fetches = {
      'critic_loss': tf.reduce_mean(loss, axis=0),
      'q_variance': tf.reduce_mean(tf.math.reduce_variance(q_tm1, axis=1),
                                   axis=0),
      'q_average': tf.reduce_mean(q_tm1),
  }
  if self._alpha:
    fetches.update({
        'push_up': tf.reduce_mean(push_up, axis=0),
        'push_down': tf.reduce_mean(push_down, axis=0),
        'regularizer': tf.reduce_mean(push_down - push_up, axis=0),
        'cql_loss': cql_loss,
    })

  return fetches

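# --- Standalone numeric sketch (toy values, assumed) of the CQL regularizer computed
# above: "push_down" is a temperature-scaled soft-maximum of the Q-values over actions,
# "push_up" is the expected Q under the empirical behaviour policy, and alpha weights
# their gap. The tensors below only illustrate the arithmetic, not the learner's data.
import tensorflow as tf

q_tm1 = tf.constant([[1.0, 2.0, 0.5]])          # [B=1, A=3] toy Q-values
policy_probs = tf.constant([[0.2, 0.7, 0.1]])   # behaviour-policy action probabilities
tr, alpha = 1.0, 0.5                            # temperature and CQL weight

push_down = tf.reduce_logsumexp(q_tm1 * tr, axis=1) / tr  # soft-maximum of Q
push_up = tf.reduce_sum(policy_probs * q_tm1, axis=1)     # E_{pi_beta}[Q]
regularizer = alpha * (push_down - push_up)
print(push_down.numpy(), push_up.numpy(), regularizer.numpy())
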