def virtual_adversarial_loss(x, logit, is_training=True, name="vat_loss"):
    """VAT loss: KL divergence between the model's prediction on x and on
    the virtually-adversarially perturbed x + r_vadv.

    The clean prediction is wrapped in stop_gradient so it acts as a fixed
    target; gradients flow only through the perturbed forward pass.
    """
    r_vadv = generate_virtual_adversarial_perturbation(
        x, logit, is_training=is_training)
    # Fixed target distribution (no gradient through the clean branch).
    target_logit = tf.stop_gradient(logit)
    perturbed_logit = forward(
        x + r_vadv, update_batch_stats=False, is_training=is_training)
    loss = L.kl_divergence_with_logit(target_logit, perturbed_logit)
    return tf.identity(loss, name=name)
def distance(q_logit, p_logit):
    """Distance between two predictions, selected by FLAGS.dist.

    'KL' -> KL divergence on logits; 'FM' -> mean feature matching.
    Raises NotImplementedError for any other setting.
    """
    if FLAGS.dist == 'KL':
        return L.kl_divergence_with_logit(q_logit, p_logit)
    if FLAGS.dist == 'FM':
        return L.mean_feature_matching(q_logit, p_logit)
    raise NotImplementedError
def virtual_adversarial_loss(x, u, logit, is_training=True, name="vat_loss"):
    """VAT loss variant that refines a caller-supplied direction u.

    Returns a (loss, u_prime) pair: the KL divergence (scaled by 1/epsilon)
    between the clean prediction and the prediction at x + epsilon*u_prime,
    plus the refined unit direction u_prime for reuse by the caller.
    """
    u_prime = generate_virtual_adversarial_perturbation(
        x, u, logit, is_training=is_training)
    # Clean prediction is a constant target; no gradient through it.
    target_logit = tf.stop_gradient(logit)
    perturbed_logit = forward(
        x + FLAGS.epsilon * u_prime,
        update_batch_stats=False, is_training=is_training)
    loss = L.kl_divergence_with_logit(target_logit, perturbed_logit) / FLAGS.epsilon
    return tf.identity(loss, name=name), u_prime
def generate_virtual_adversarial_dropout_mask(x, logit, is_training=True):
    """Find an adversarial dropout mask by flipping a random mask along the
    gradient of the prediction divergence.

    NOTE(review): the is_training parameter is unused here — CNN.logit is
    called with is_training=True unconditionally; confirm this is intended.
    """
    # Forward pass with a fresh stochastic mask (fixed seed for determinism
    # within the graph); batch stats are left untouched.
    stochastic_logit, random_mask = CNN.logit(
        x, None, is_training=True, update_batch_stats=False,
        stochastic=True, seed=1234)
    divergence = L.kl_divergence_with_logit(stochastic_logit, logit)
    # Gradient of the divergence w.r.t. the mask, treated as a constant.
    mask_gradient = tf.stop_gradient(
        tf.gradients(divergence, [random_mask], aggregation_method=2)[0])
    return flipping_algorithm(random_mask, mask_gradient)
def generate_virtual_adversarial_perturbation(x, logit, is_training=True):
    """Approximate the adversarial direction by power iteration and return
    it scaled to length FLAGS.epsilon.

    Each iteration probes the model at x + xi*d_hat and replaces d with the
    gradient of the KL divergence w.r.t. the probe offset.
    """
    direction = tf.random_normal(shape=tf.shape(x))
    for _ in range(FLAGS.num_power_iterations):
        probe = FLAGS.xi * get_normalized_vector(direction)
        probed_logit = forward(
            x + probe, update_batch_stats=False, is_training=is_training)
        divergence = L.kl_divergence_with_logit(logit, probed_logit)
        grad = tf.gradients(divergence, [probe], aggregation_method=2)[0]
        direction = tf.stop_gradient(grad)
    return FLAGS.epsilon * get_normalized_vector(direction)
def virtual_adversarial_dropout_loss(x, logit, is_training=True, name="vadt_loss"):
    """VAdD loss: KL divergence between the clean prediction and the
    prediction under the adversarial dropout mask.

    NOTE(review): unlike virtual_adversarial_loss, logit is NOT wrapped in
    stop_gradient here, and batch stats ARE updated — confirm both are
    intentional.
    """
    adv_mask = generate_virtual_adversarial_dropout_mask(
        x, logit, is_training=is_training)
    masked_logit, _ = CNN.logit(
        x, adv_mask, is_training=True, update_batch_stats=True,
        stochastic=True, seed=1234)
    loss = L.kl_divergence_with_logit(logit, masked_logit)
    return tf.identity(loss, name=name)
def generate_virtual_adversarial_perturbation(x, u, logit, is_training=True):
    """Refine a caller-supplied direction u by power iteration and return
    the resulting unit-norm adversarial direction.
    """
    direction = u
    for _ in range(FLAGS.num_power_iterations):
        probe = FLAGS.xi * direction
        probed_logit = forward(
            x + probe, update_batch_stats=False, is_training=is_training)
        # TODO use L2 instead. not just here but in virtual_adversarial_loss
        divergence = L.kl_divergence_with_logit(logit, probed_logit)
        grad = tf.gradients(divergence, [probe], aggregation_method=2)[0]
        direction = tf.stop_gradient(grad)
    return get_normalized_vector(direction)
def generate_virtual_adversarial_perturbation(x, logit, is_training=True):
    """Power-iteration VAT perturbation with an optionally randomized xi.

    When FLAGS.xi_stddev > 0, xi is drawn per sample as |N(0, xi_stddev)|
    clipped to [FLAGS.xi, 10]; otherwise the scalar FLAGS.xi is used.

    BUG FIX: the original computed `xi` but then used the constant FLAGS.xi
    inside the power iteration, leaving the computed value dead. The loop
    now uses `xi`; the per-sample variant is reshaped to (batch, 1, ..., 1)
    so it broadcasts over d's trailing dimensions.
    """
    d = tf.random_normal(shape=tf.shape(x))
    if FLAGS.xi_stddev == 0:
        xi = FLAGS.xi  # scalar step size
    else:
        # Per-sample step size: |N(0, xi_stddev)| clipped to [FLAGS.xi, 10].
        xi = tf.clip_by_value(
            tf.abs(tf.random_normal(
                shape=x.shape[0:1], mean=0, stddev=FLAGS.xi_stddev)),
            clip_value_min=FLAGS.xi, clip_value_max=10)
        # Reshape (batch,) -> (batch, 1, ..., 1) for broadcasting against d.
        xi = tf.reshape(
            xi,
            tf.concat([tf.shape(x)[:1],
                       tf.ones([tf.rank(x) - 1], dtype=tf.int32)], axis=0))
    for _ in range(FLAGS.num_power_iterations):
        d = xi * get_normalized_vector(d)
        logit_m = forward(x + d, update_batch_stats=False,
                          is_training=is_training)
        dist = L.kl_divergence_with_logit(logit, logit_m)
        grad = tf.gradients(dist, [d], aggregation_method=2)[0]
        d = tf.stop_gradient(grad)
    return FLAGS.epsilon * get_normalized_vector(d)
def build_training_graph(x_1, x_2, y, ul_x_1, ul_x_2, lr, mom, lamb):
    """Build the semi-supervised training graph for the selected FLAGS.method.

    Args:
        x_1, x_2: two augmentations of the labeled batch.
        y: labels for x_1/x_2.
        ul_x_1, ul_x_2: two augmentations of the unlabeled batch.
        lr, mom: Adam learning rate and beta1.
        lamb: weight on the consistency (additional) loss.

    Returns:
        (loss, train_op, global_step).

    Fixes vs. original:
      - baseline branch used `vat.forward`, inconsistent with the `adt.`
        module used everywhere else (and `vat` is not defined here).
      - baseline branch never assigned `loss`, which would raise NameError
        at compute_gradients; it now trains on the supervised loss only.
    """
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )
    logit = adt.forward(x_1, update_batch_stats=True)
    nll_loss = L.ce_loss(logit, y)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        if FLAGS.method == 'VAdD-KL':
            ul_logit = adt.forward(ul_x_1)
            ul_adt_logit = adt.forward_adv_drop(
                ul_x_2, ul_logit, FLAGS.delta,
                is_training=True, mode=FLAGS.method)
            additional_loss = L.kl_divergence_with_logit(
                ul_logit, ul_adt_logit)  # *4.0
            ent_loss = L.entropy_y_x(ul_logit)
            loss = nll_loss + lamb * additional_loss + ent_loss
        elif FLAGS.method == 'VAT+VAdD-KL':
            ul_logit = adt.forward(ul_x_1)
            ul_adt_logit = adt.forward_adv_drop(
                ul_x_2, ul_logit, FLAGS.delta,
                is_training=True, mode='VAdD-KL')
            additional_loss = L.kl_divergence_with_logit(
                ul_logit, ul_adt_logit)  # *4.0
            ent_loss = L.entropy_y_x(ul_logit)
            vat_loss = adt.virtual_adversarial_loss(ul_x_1, ul_logit)
            loss = nll_loss + lamb * additional_loss + vat_loss + ent_loss
        elif FLAGS.method == 'VAdD-QE':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=True)
            ul_adt_logit = adt.forward_adv_drop(
                ul_x_2, ul_logit, FLAGS.delta,
                is_training=True, update_batch_stats=True, mode=FLAGS.method)
            additional_loss = L.qe_loss(ul_logit, ul_adt_logit)  # *4.0
            ent_loss = L.entropy_y_x(ul_logit)
            loss = nll_loss + lamb * additional_loss + ent_loss
        elif FLAGS.method == 'VAT+VAdD-QE':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=True)
            ul_adt_logit = adt.forward_adv_drop(
                ul_x_2, ul_logit, FLAGS.delta,
                is_training=True, update_batch_stats=True, mode='VAdD-QE')
            additional_loss = L.qe_loss(ul_logit, ul_adt_logit)  # *4.0
            ent_loss = L.entropy_y_x(ul_logit)
            vat_loss = adt.virtual_adversarial_loss(ul_x_1, ul_logit)
            loss = nll_loss + lamb * additional_loss + vat_loss + ent_loss
        elif FLAGS.method == 'VAT':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=False)
            ent_loss = L.entropy_y_x(ul_logit)  # + L.entropy_y_x(ul_adt_logit)
            vat_loss = adt.virtual_adversarial_loss(ul_x_1, ul_logit)
            loss = nll_loss + vat_loss + ent_loss
        elif FLAGS.method == 'Pi':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=True)
            ul_adt_logit = adt.forward(ul_x_2, update_batch_stats=True)
            additional_loss = L.qe_loss(ul_logit, ul_adt_logit)  # *4.0
            ent_loss = L.entropy_y_x(ul_logit)
            loss = nll_loss + lamb * additional_loss + ent_loss
        elif FLAGS.method == 'baseline':
            # Supervised-only: recompute the logit under variable reuse.
            # (Was `vat.forward`; `adt.forward` matches the rest of the file.)
            logit = adt.forward(x_1)
            nll_loss = L.ce_loss(logit, y)
            scope = tf.get_variable_scope()
            scope.reuse_variables()
            additional_loss = 0
            # Was missing: without this, `loss` is undefined below.
            loss = nll_loss
        else:
            raise NotImplementedError
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom, beta2=0.999)
        tvars = tf.trainable_variables()
        grads_and_vars = opt.compute_gradients(loss, tvars)
        train_op = opt.apply_gradients(grads_and_vars,
                                       global_step=global_step)
    return loss, train_op, global_step
def build_training_graph(x_1, x_2, y, lr, mom, lamb):
    """Build the supervised training graph for the selected FLAGS.method.

    Args:
        x_1, x_2: two augmentations of the labeled batch.
        y: labels.
        lr, mom: Adam learning rate and beta1.
        lamb: weight on the consistency (additional) loss.

    Returns:
        (loss, train_op, global_step).

    Fixes vs. original:
      - VAT branch computed `logit_p = tf.stop_gradient(logit)` and then
        immediately overwrote it with `logit_p = logit`, making the
        stop_gradient dead; the overwrite is removed (the intent — a fixed
        target — matches virtual_adversarial_loss, which also applies
        stop_gradient internally, so the training signal is unchanged).
      - baseline branch never assigned `loss` (NameError at
        compute_gradients) and assigned the unused local `adt_masks` from an
        undefined name `masks`; it now trains on the supervised loss only.
    """
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )
    logit = adt.forward(x_1)
    nll_loss = L.ce_loss(logit, y)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        if FLAGS.method == 'SAdD':
            adt_logit = adt.forward_adv_drop(x_2, y, FLAGS.delta,
                                             is_training=True)
            additional_loss = L.ce_loss(adt_logit, y)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'VAdD-KL':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2, logit_p, FLAGS.delta,
                                             is_training=True,
                                             mode=FLAGS.method)
            additional_loss = L.kl_divergence_with_logit(logit_p, adt_logit)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'VAdD-QE':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2, logit_p, FLAGS.delta,
                                             is_training=True,
                                             mode=FLAGS.method)
            additional_loss = L.qe_loss(adt_logit, logit_p)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'VAT+VAdD-KL':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2, logit_p, FLAGS.delta,
                                             is_training=True,
                                             mode='VAdD-KL')
            additional_loss = L.kl_divergence_with_logit(logit_p, adt_logit)
            vat_loss = adt.virtual_adversarial_loss(x_1, logit_p)
            loss = nll_loss + lamb * additional_loss + vat_loss
        elif FLAGS.method == 'VAT+VAdD-QE':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2, logit_p, FLAGS.delta,
                                             is_training=True,
                                             mode='VAdD-QE')
            additional_loss = L.qe_loss(adt_logit, logit_p)
            vat_loss = adt.virtual_adversarial_loss(x_1, logit_p)
            loss = nll_loss + lamb * additional_loss + vat_loss
        elif FLAGS.method == 'VAT':
            # Fixed target for the VAT loss (the dead overwrite is removed).
            logit_p = tf.stop_gradient(logit)
            vat_loss = adt.virtual_adversarial_loss(x_1, logit_p)
            loss = nll_loss + vat_loss
        elif FLAGS.method == 'Pi':
            adt_logit = adt.forward(x_2)
            additional_loss = L.qe_loss(adt_logit, logit)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'baseline':
            additional_loss = 0
            # Was missing: without this, `loss` is undefined below.
            loss = nll_loss
        else:
            raise NotImplementedError
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom)
        tvars = tf.trainable_variables()
        grads_and_vars = opt.compute_gradients(loss, tvars)
        train_op = opt.apply_gradients(grads_and_vars,
                                       global_step=global_step)
    return loss, train_op, global_step