def tower_func(image, label):
    """Evaluation-only tower that scores the model on attacked inputs.

    The batch is preprocessed and permuted with ``[0, 3, 1, 2]``, then handed
    to ``attacker.attack``, which returns the perturbed batch plus a target
    label. The model's logits on the perturbed batch are passed to the
    top-1/top-5 metric helper and to the attack-success helper.

    ``self`` and ``attacker`` are taken from the enclosing scope.

    Args:
        image: raw input batch, passed to ``self.image_preprocess``.
        label: ground-truth labels for ``image``.

    Returns:
        None. The return values of the metric helpers are discarded; the
        function is called only for the graph ops it builds.
    """
    assert not self.training
    # Moves axis 3 to position 1 -- presumably NHWC -> NCHW (confirm).
    nchw_batch = tf.transpose(self.image_preprocess(image), perm=[0, 3, 1, 2])
    perturbed, target_label = attacker.attack(nchw_batch, label, self.get_logits)
    adv_logits = self.get_logits(perturbed)
    # compute top-1 and top-5
    ImageNetModel.compute_loss_and_error(adv_logits, label)
    AdvImageNetModel.compute_attack_success(adv_logits, target_label)
def build_graph(self, image, label, indices):
    """The default tower function.

    Replaces the preprocessed batch with ``self.attacker``'s adversarial
    version, then evaluates the model on those adversarial images.

    Args:
        image: raw input batch, passed to ``self.image_preprocess``.
        label: ground-truth labels; used by the attacker and by the loss.
        indices: accepted for interface compatibility; not used here.

    Returns:
        ``None`` when ``self.training`` is false (only metrics are built).
        Otherwise the total cost, i.e. the loss from
        ``ImageNetModel.compute_loss_and_error`` plus the L2 weight-decay
        term. It is multiplied by ``self.loss_scale`` when that is not 1.
    """
    image = self.image_preprocess(image)
    # The permutation below moves axis 3 to position 1, so the model must be
    # configured for channels-first input.
    assert self.data_format == 'NCHW'
    image = tf.transpose(image, [0, 3, 1, 2])

    # AUTO_REUSE: the attack calls self.get_logits (creating/using the model
    # variables) and the final forward pass below reuses the same variables.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        # BatchNorm always comes with trouble. We use the testing mode of it during attack.
        # freeze_collection(UPDATE_OPS) presumably discards update ops added by
        # the attack's forward passes (tensorpack restores the collection on exit).
        with freeze_collection([tf.GraphKeys.UPDATE_OPS]), argscope(BatchNorm, training=False):
            image, target_label = self.attacker.attack(image, label, self.get_logits)
            # Treat the adversarial image as a constant input: no gradient
            # flows back through the attack into training.
            image = tf.stop_gradient(image, name='adv_training_sample')

        # Outside the argscope, so BatchNorm is not forced to training=False here.
        logits = self.get_logits(image)

    loss = ImageNetModel.compute_loss_and_error(
        logits, label, label_smoothing=self.label_smoothing)
    AdvImageNetModel.compute_attack_success(logits, target_label)
    # Evaluation towers only need the metrics built above; no cost is returned.
    if not self.training:
        return

    wd_loss = regularize_cost(self.weight_decay_pattern,
                              tf.contrib.layers.l2_regularizer(self.weight_decay),
                              name='l2_regularize_loss')
    add_moving_summary(loss, wd_loss)
    total_cost = tf.add_n([loss, wd_loss], name='cost')

    # Optionally scale the cost by a constant factor before returning it.
    if self.loss_scale != 1.:
        logger.info("Scaling the total loss by {} ...".format(self.loss_scale))
        return total_cost * self.loss_scale
    else:
        return total_cost
def tower_func(image, label):
    """Evaluation tower that attacks each batch and can save the results.

    The batch is preprocessed, permuted with ``[0, 3, 1, 2]`` and attacked
    with ``attacker``. The model's logits on the adversarial batch are then
    passed to the top-1/top-5 metric helper and the attack-success helper.

    When ``save`` is truthy, the original images, adversarial images, labels,
    target labels and logits are passed through ``adv_tf_record.save_adv``
    (together with ``data_obj``) inside a ``tf.py_func``.

    ``self``, ``attacker``, ``save`` and ``data_obj`` are taken from the
    enclosing scope.

    Args:
        image: raw input batch, passed to ``self.image_preprocess``.
        label: ground-truth labels for ``image``.

    Returns:
        None. The function is called only for the graph ops it builds.
    """
    assert not self.training
    image_orig = self.image_preprocess(image)
    # Moves axis 3 to position 1 -- presumably NHWC -> NCHW (confirm).
    image_orig = tf.transpose(image_orig, [0, 3, 1, 2])

    # NOTE(review): 'palatte' looks like a misspelling of 'palette', but it is
    # the attribute name set elsewhere, so it must stay as-is here.
    if hasattr(self, 'palatte'):
        image_adv, target_label = attacker.attack(
            image_orig, label, self.get_logits_raw, self.palatte)
    else:
        image_adv, target_label = attacker.attack(
            image_orig, label, self.get_logits)
    logits = self.get_logits(image_adv)

    if save:
        from adv_tf_record import save_adv

        def _save(img_orig, img_adv, lbl, tgt_lbl, lgts):
            # Receives the evaluated values of the five tensors below. It must
            # return five values matching their dtypes, because py_func emits
            # them as outputs. TODO confirm save_adv passes them through.
            return save_adv(img_orig, img_adv, lbl, tgt_lbl, lgts, data_obj)

        tensors = [image_orig, image_adv, label, target_label, logits]
        outputs = tf.py_func(_save, tensors, [t.dtype for t in tensors])
        # tf.py_func outputs have unknown static shape; restore the shapes of
        # the corresponding inputs so shape-dependent ops downstream still work.
        for out, src in zip(outputs, tensors):
            out.set_shape(src.get_shape())
        # The metrics below consume the py_func outputs, so evaluating them
        # also runs save_adv.
        image_orig, image_adv, label, target_label, logits = outputs

    ImageNetModel.compute_loss_and_error(logits, label)  # compute top-1 and top-5
    AdvImageNetModel.compute_attack_success(logits, target_label)