def build_graph(self, image, label, indices):
    """
    Tower function for adversarial training.

    Adversarial images are produced by ``self.attacker.attack`` and the
    classification loss is then computed on those images.

    Args:
        image: input image batch. It goes through ``self.image_preprocess``
            and is transposed with perm ``[0, 3, 1, 2]`` (presumably
            NHWC -> NCHW; only ``data_format == 'NCHW'`` is accepted).
        label: labels handed to both the attacker and the loss.
        indices: accepted but not used in this function.

    Returns:
        ``None`` when ``self.training`` is false. Otherwise the sum of the
        classification loss and the weight-decay cost, multiplied by
        ``self.loss_scale`` when that value differs from 1.
    """
    image = self.image_preprocess(image)
    assert self.data_format == 'NCHW'
    image = tf.transpose(image, [0, 3, 1, 2])

    # AUTO_REUSE: ``get_logits`` is handed to the attacker and also called
    # again below, inside the same variable scope.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        # BatchNorm always comes with trouble. During the attack it runs with
        # training=False and the UPDATE_OPS collection is frozen.
        with freeze_collection([tf.GraphKeys.UPDATE_OPS]), \
                argscope(BatchNorm, training=False):
            image, target_label = self.attacker.attack(
                image, label, self.get_logits)
            # The attacked image is treated as a constant training sample.
            image = tf.stop_gradient(image, name='adv_training_sample')

        adv_logits = self.get_logits(image)

    cls_loss = ImageNetModel.compute_loss_and_error(
        adv_logits, label, label_smoothing=self.label_smoothing)
    AdvImageNetModel.compute_attack_success(adv_logits, target_label)

    # Nothing more to build outside of training.
    if not self.training:
        return

    weight_decay_cost = regularize_cost(
        self.weight_decay_pattern,
        tf.contrib.layers.l2_regularizer(self.weight_decay),
        name='l2_regularize_loss')
    add_moving_summary(cls_loss, weight_decay_cost)
    cost = tf.add_n([cls_loss, weight_decay_cost], name='cost')

    if self.loss_scale == 1.:
        return cost

    message = "Scaling the total loss by {} ...".format(self.loss_scale)
    logger.info(message)
    return cost * self.loss_scale
def _build_graph(self, inputs):
    """
    Build the Q-value prediction and, in training towers, the TD cost.

    Args:
        inputs: a sequence of four tensors
            ``(comb_state, action, reward, isOver)``. ``comb_state`` is
            sliced as a rank-4 tensor along its last axis: entries
            ``[0, self.channel)`` form the current state and entries
            ``[1, self.channel + 1)`` form the next state.

    Side effects:
        Always sets ``self.predict_value``. In training it also sets
        ``self.cost`` (and ``self.greedy_choice`` when ``self.method`` is
        ``'Double'``) and registers moving/parameter summaries.

    Returns:
        None.
    """
    stacked, action, reward, is_over = inputs
    stacked = tf.cast(stacked, tf.float32)
    state = tf.slice(
        stacked, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')
    self.predict_value = self._get_DQN_prediction(state)

    # Inference towers only need the prediction.
    if not get_current_tower_context().is_training:
        return

    reward = tf.clip_by_value(reward, -1, 1)
    # Same window as ``state`` but shifted by one along the last axis.
    next_state = tf.slice(
        stacked, [0, 0, 0, 1], [-1, -1, -1, self.channel], name='next_state')
    taken_mask = tf.one_hot(action, self.num_actions, 1.0, 0.0)

    # Q-value of the action that was actually taken.  # N,
    q_taken = tf.reduce_sum(self.predict_value * taken_mask, 1)
    mean_max_q = tf.reduce_mean(
        tf.reduce_max(self.predict_value, 1), name='predict_reward')
    summary.add_moving_summary(mean_max_q)

    # Second copy of the network under the 'target' scope; the
    # TRAINABLE_VARIABLES collection is frozen while it is built.
    with tf.variable_scope('target'), \
            collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
        target_q = self._get_DQN_prediction(next_state)  # NxA

    if self.method == 'Double':
        # Double-DQN: the action is chosen by re-applying the network in the
        # current scope (reuse=True) and then evaluated with ``target_q``.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            online_next_q = self._get_DQN_prediction(next_state)
        self.greedy_choice = tf.argmax(online_next_q, 1)  # N,
        greedy_mask = tf.one_hot(
            self.greedy_choice, self.num_actions, 1.0, 0.0)
        best_v = tf.reduce_sum(target_q * greedy_mask, 1)
    else:
        # DQN: plain maximum over the target predictions.
        best_v = tf.reduce_max(target_q, 1)  # N,

    # The bootstrap term is dropped where ``isOver`` is set, and no gradient
    # flows through ``best_v``.
    not_over = 1.0 - tf.cast(is_over, tf.float32)
    target = reward + not_over * self.gamma * tf.stop_gradient(best_v)

    td_error = target - q_taken
    self.cost = tf.reduce_mean(symbf.huber_loss(td_error), name='cost')

    # monitor all W
    summary.add_param_summary(
        ('conv.*/W', ['histogram', 'rms']),
        ('fc.*/W', ['histogram', 'rms']))
    summary.add_moving_summary(self.cost)
def _build_graph(self, inputs):
    """
    Build Q-value and per-agent policy predictions and, in training towers,
    the combined cost.

    Args:
        inputs: a sequence of five tensors
            ``(comb_state, action, reward, isOver, action_o)``.
            ``action``, ``reward`` and ``isOver`` are sliced as rank-2
            tensors and ``action_o`` as a rank-3 tensor; only the last
            ``self.update_step`` entries along axis 1 (starting at
            ``self.channel - self.update_step``) are kept, then flattened
            so the leading dimension is ``batch_size * update_step``.

    Side effects:
        Always sets ``self.batch_size``, ``self.predict_value``,
        ``self.q_rnn_state_out`` and ``self.pi_rnn_state_out``. In training
        it also sets ``self.cost`` (and ``self.greedy_choice`` when
        ``self.method`` is ``'Double'``) and registers summaries.

    Returns:
        None.
    """
    comb_state, action, reward, is_over, action_o = inputs
    self.batch_size = tf.shape(comb_state)[0]

    # Keep only the trailing ``update_step`` entries of the per-step tensors.
    steps = self.update_step
    first_step = self.channel - steps
    action = tf.slice(action, [0, first_step], [-1, steps])
    reward = tf.slice(reward, [0, first_step], [-1, steps])
    is_over = tf.slice(is_over, [0, first_step], [-1, steps])
    action_o = tf.slice(
        action_o, [0, first_step, 0], [-1, steps, self.num_agents])

    # Fold the step axis into the batch axis.
    action = tf.reshape(action, (self.batch_size * steps,))
    reward = tf.reshape(reward, (self.batch_size * steps,))
    is_over = tf.reshape(is_over, (self.batch_size * steps,))
    action_o = tf.reshape(
        action_o, (self.batch_size * steps, self.num_agents))

    comb_state = tf.cast(comb_state, tf.float32)
    state = tf.slice(
        comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')
    (self.predict_value,
     pi_value,
     self.q_rnn_state_out,
     self.pi_rnn_state_out) = self._get_DQN_prediction(state)

    # Inference towers only need the predictions.
    if not get_current_tower_context().is_training:
        return

    reward = tf.clip_by_value(reward, -1, 1)
    # Same window as ``state`` but shifted by one along the last axis.
    next_state = tf.slice(
        comb_state, [0, 0, 0, 1], [-1, -1, -1, self.channel],
        name='next_state')
    taken_mask = tf.one_hot(action, self.num_actions, 1.0, 0.0)

    # Q-value of the action that was actually taken.  # N,
    q_taken = tf.reduce_sum(self.predict_value * taken_mask, 1)
    mean_max_q = tf.reduce_mean(
        tf.reduce_max(self.predict_value, 1), name='predict_reward')
    summary.add_moving_summary(mean_max_q)

    # Second copy of the network under the 'target' scope; the
    # TRAINABLE_VARIABLES collection is frozen while it is built.
    with tf.variable_scope('target'), \
            collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
        target_q, _, _, _ = self._get_DQN_prediction(next_state)  # NxA

    if self.method == 'Double':
        # Double-DQN: the action is chosen by re-applying the network in the
        # current scope (reuse=True) and then evaluated with ``target_q``.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            online_next_q, _, _, _ = self._get_DQN_prediction(next_state)
        self.greedy_choice = tf.argmax(online_next_q, 1)  # N,
        greedy_mask = tf.one_hot(
            self.greedy_choice, self.num_actions, 1.0, 0.0)
        best_v = tf.reduce_sum(target_q * greedy_mask, 1)
    else:
        # DQN: plain maximum over the target predictions.
        best_v = tf.reduce_max(target_q, 1)  # N,

    # The bootstrap term is dropped where ``isOver`` is set, and no gradient
    # flows through ``best_v``.
    not_over = 1.0 - tf.cast(is_over, tf.float32)
    target = reward + not_over * self.gamma * tf.stop_gradient(best_v)

    # q cost (per-sample; reduced further below)
    q_cost = symbf.huber_loss(target - q_taken)

    # pi cost: one cross-entropy term per column of ``action_o``.
    action_os = tf.unstack(action_o, self.num_agents, axis=1)
    other_onehots = [
        tf.one_hot(other, self.num_actions, 1.0, 0.0) for other in action_os
    ]
    mt_type = self.mt_type
    pi_costs = []
    for idx, onehot in enumerate(other_onehots):
        # Coop-only: disable opponent loss (every index after the first).
        # Opponent-only: disable collaborator loss (index 0).
        muted = ((mt_type == 'coop-only' and idx > 0)
                 or (mt_type == 'opponent-only' and idx == 0))
        scale = 0.0 if muted else 1.0
        xent = tf.nn.softmax_cross_entropy_with_logits(
            labels=onehot, logits=pi_value[idx])
        pi_costs.append(scale * xent)
    pi_cost = self.lamb * tf.add_n(pi_costs)

    if self.reg:
        # Weight the Q cost by sqrt(1 / mean(pi_cost)), treated as a constant.
        mean_pi_cost = tf.reduce_mean(pi_cost)
        inv_pi_cost = 1.0 / (mean_pi_cost + 1e-9)
        reg_coff = tf.stop_gradient(tf.sqrt(inv_pi_cost), name='reg')
        self.cost = tf.reduce_mean(reg_coff * q_cost + pi_cost)
        summary.add_moving_summary(reg_coff)
    else:
        self.cost = tf.reduce_mean(q_cost + pi_cost)

    # monitor all W
    summary.add_param_summary(
        ('conv.*/W', ['histogram', 'rms']),
        ('fc.*/W', ['histogram', 'rms']))
    summary.add_moving_summary(self.cost)
    summary.add_moving_summary(tf.reduce_mean(pi_cost, name='pi_cost'))
    summary.add_moving_summary(tf.reduce_mean(q_cost, name='q_cost'))

    # Accuracy of each policy head against the corresponding ``action_o``
    # column.
    for idx, observed in enumerate(action_os):
        predicted = tf.argmax(pi_value[idx], axis=1)
        accuracy = tf.contrib.metrics.accuracy(
            predicted, observed, name='acc-%d' % idx)
        summary.add_moving_summary(accuracy)
def maybe_freeze_updates(enable):
    """
    Yield exactly once, optionally inside a frozen UPDATE_OPS collection.

    This is a generator; it is presumably meant to be wrapped as a context
    manager (no decorator is visible at this definition).

    Args:
        enable: when truthy, the single ``yield`` happens inside
            ``freeze_collection([tf.GraphKeys.UPDATE_OPS])``; when falsy,
            it happens with nothing around it.
    """
    if not enable:
        # Nothing to freeze: hand control straight back to the caller.
        yield
        return

    with freeze_collection([tf.GraphKeys.UPDATE_OPS]):
        yield