Example #1
0
def virtual_adversarial_loss(x, logit, is_training=True, name="vat_loss"):
    """Return the VAT loss: KL divergence between predictions on the clean
    input and on the adversarially perturbed input.

    Gradients are blocked through the clean-input logits so the clean
    prediction acts as a fixed target.
    """
    perturbation = generate_virtual_adversarial_perturbation(
        x, logit, is_training=is_training)
    clean_logit = tf.stop_gradient(logit)
    perturbed_logit = forward(
        x + perturbation, update_batch_stats=False, is_training=is_training)
    kl = L.kl_divergence_with_logit(clean_logit, perturbed_logit)
    return tf.identity(kl, name=name)
Example #2
0
def distance(q_logit, p_logit):
    """Dispatch the divergence measure selected by ``FLAGS.dist``.

    'KL' -> KL divergence with logits; 'FM' -> mean feature matching.
    Raises NotImplementedError for any other setting.
    """
    measures = {
        'KL': L.kl_divergence_with_logit,
        'FM': L.mean_feature_matching,
    }
    if FLAGS.dist not in measures:
        raise NotImplementedError
    return measures[FLAGS.dist](q_logit, p_logit)
Example #3
0
def virtual_adversarial_loss(x, u, logit, is_training=True, name="vat_loss"):
    """Power-iteration VAT loss that also returns the refined direction.

    The adversarial direction is refined from the running estimate ``u``,
    the input is displaced by ``FLAGS.epsilon`` along it, and the resulting
    KL divergence is divided by ``FLAGS.epsilon`` (finite-difference
    normalization). Returns ``(loss, refined_direction)``.
    """
    direction = generate_virtual_adversarial_perturbation(
        x, u, logit, is_training=is_training)
    # Treat the clean prediction as a fixed target.
    target_logit = tf.stop_gradient(logit)
    adv_logit = forward(x + FLAGS.epsilon * direction,
                        update_batch_stats=False,
                        is_training=is_training)
    scaled_kl = L.kl_divergence_with_logit(target_logit,
                                           adv_logit) / FLAGS.epsilon
    return tf.identity(scaled_kl, name=name), direction
def generate_virtual_adversarial_dropout_mask(x, logit, is_training=True):
    """Derive an adversarial dropout mask by flipping the initial stochastic
    mask along the gradient of the KL divergence to the reference logits.

    NOTE(review): ``is_training`` is accepted but never forwarded — the CNN
    call below hardcodes ``is_training=True`` (presumably dropout must be
    active to obtain a mask); confirm this is intentional.
    """
    stochastic_logit, init_mask = CNN.logit(x,
                                            None,
                                            is_training=True,
                                            update_batch_stats=False,
                                            stochastic=True,
                                            seed=1234)
    divergence = L.kl_divergence_with_logit(stochastic_logit, logit)
    # Gradient w.r.t. the mask, detached so the mask search itself is not
    # part of the backward graph.
    mask_grad = tf.stop_gradient(
        tf.gradients(divergence, [init_mask], aggregation_method=2)[0])
    return flipping_algorithm(init_mask, mask_grad)
Example #5
0
def generate_virtual_adversarial_perturbation(x, logit, is_training=True):
    """Estimate the most KL-sensitive input direction via power iteration.

    Starts from Gaussian noise, repeatedly probes the model at a small
    ``FLAGS.xi`` step and follows the gradient of the KL divergence, then
    returns the direction scaled to radius ``FLAGS.epsilon``.
    """
    direction = tf.random_normal(shape=tf.shape(x))

    for _ in range(FLAGS.num_power_iterations):
        probe = FLAGS.xi * get_normalized_vector(direction)
        perturbed_logit = forward(
            x + probe, update_batch_stats=False, is_training=is_training)
        divergence = L.kl_divergence_with_logit(logit, perturbed_logit)
        grad = tf.gradients(divergence, [probe], aggregation_method=2)[0]
        # Detach so successive iterations do not backprop through each other.
        direction = tf.stop_gradient(grad)

    return FLAGS.epsilon * get_normalized_vector(direction)
def virtual_adversarial_dropout_loss(x,
                                     logit,
                                     is_training=True,
                                     name="vadt_loss"):
    """VAdD loss: KL divergence between the reference logits and the network
    evaluated under the adversarial dropout mask.

    NOTE(review): ``is_training`` is accepted but the CNN call hardcodes
    ``is_training=True``, and unlike the VAT losses the reference logits are
    not wrapped in stop_gradient here — confirm both are intentional.
    """
    adv_mask = generate_virtual_adversarial_dropout_mask(
        x, logit, is_training=is_training)
    adv_logit, _ = CNN.logit(x,
                             adv_mask,
                             is_training=True,
                             update_batch_stats=True,
                             stochastic=True,
                             seed=1234)
    divergence = L.kl_divergence_with_logit(logit, adv_logit)
    return tf.identity(divergence, name=name)
Example #7
0
def generate_virtual_adversarial_perturbation(x, u, logit, is_training=True):
    """Refine the perturbation direction ``u`` with power iterations.

    Returns the normalized direction only (unscaled); the caller applies
    ``FLAGS.epsilon`` when displacing the input.
    """
    direction = u

    for _ in range(FLAGS.num_power_iterations):
        probe = FLAGS.xi * direction
        perturbed_logit = forward(x + probe,
                                  update_batch_stats=False,
                                  is_training=is_training)
        # TODO use L2 instead. not just here but in virtual_adversarial_loss
        divergence = L.kl_divergence_with_logit(logit, perturbed_logit)
        grad = tf.gradients(divergence, [probe], aggregation_method=2)[0]
        # Detach and re-normalize for the next iteration.
        direction = get_normalized_vector(tf.stop_gradient(grad))

    return direction
Example #8
0
def generate_virtual_adversarial_perturbation(x, logit, is_training=True):
    """Power-iteration estimate of the adversarial direction, with an
    optionally randomized xi step size.

    When ``FLAGS.xi_stddev > 0``, a per-example ``|N(0, xi_stddev)|`` draw
    clipped to ``[FLAGS.xi, 10]`` replaces the fixed ``FLAGS.xi`` step.
    Returns the direction scaled to radius ``FLAGS.epsilon``.
    """
    d = tf.random_normal(shape=tf.shape(x))
    if FLAGS.xi_stddev == 0:
        xi = FLAGS.xi
    else:
        xi = tf.clip_by_value(tf.abs(
            tf.random_normal(shape=x.shape[0:1],
                             mean=0,
                             stddev=FLAGS.xi_stddev)),
                              clip_value_min=FLAGS.xi,
                              clip_value_max=10)
    for _ in range(FLAGS.num_power_iterations):
        # BUG FIX: the original used FLAGS.xi here, which made the
        # randomized-xi branch above dead code. Use the computed xi.
        # NOTE(review): the randomized xi has shape [batch]; broadcasting
        # against d (shape of x) may require a reshape if x has rank > 1 —
        # TODO confirm against callers.
        d = xi * get_normalized_vector(d)
        logit_p = logit
        logit_m = forward(x + d,
                          update_batch_stats=False,
                          is_training=is_training)
        dist = L.kl_divergence_with_logit(logit_p, logit_m)
        grad = tf.gradients(dist, [d], aggregation_method=2)[0]
        d = tf.stop_gradient(grad)

    return FLAGS.epsilon * get_normalized_vector(d)
def build_training_graph(x_1, x_2, y, ul_x_1, ul_x_2, lr, mom, lamb):
    """Build the semi-supervised training graph for the method selected by
    ``FLAGS.method``.

    Combines the supervised cross-entropy loss on (x_1, y) with an
    unsupervised regularizer computed on the unlabeled pair (ul_x_1, ul_x_2),
    weighted by ``lamb``. Returns ``(loss, train_op, global_step)``.
    """
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )
    logit = adt.forward(x_1, update_batch_stats=True)
    nll_loss = L.ce_loss(logit, y)

    # Reuse the variables created by the supervised forward pass above.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):

        if FLAGS.method == 'VAdD-KL':
            ul_logit = adt.forward(ul_x_1)
            ul_adt_logit = adt.forward_adv_drop(ul_x_2,
                                                ul_logit,
                                                FLAGS.delta,
                                                is_training=True,
                                                mode=FLAGS.method)
            additional_loss = L.kl_divergence_with_logit(
                ul_logit, ul_adt_logit)  #*4.0
            ent_loss = L.entropy_y_x(ul_logit)
            loss = nll_loss + lamb * additional_loss + ent_loss

        elif FLAGS.method == 'VAT+VAdD-KL':
            ul_logit = adt.forward(ul_x_1)
            ul_adt_logit = adt.forward_adv_drop(ul_x_2,
                                                ul_logit,
                                                FLAGS.delta,
                                                is_training=True,
                                                mode='VAdD-KL')
            additional_loss = L.kl_divergence_with_logit(
                ul_logit, ul_adt_logit)  #*4.0
            ent_loss = L.entropy_y_x(ul_logit)
            vat_loss = adt.virtual_adversarial_loss(ul_x_1, ul_logit)
            loss = nll_loss + lamb * additional_loss + vat_loss + ent_loss

        elif FLAGS.method == 'VAdD-QE':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=True)
            ul_adt_logit = adt.forward_adv_drop(ul_x_2,
                                                ul_logit,
                                                FLAGS.delta,
                                                is_training=True,
                                                update_batch_stats=True,
                                                mode=FLAGS.method)
            additional_loss = L.qe_loss(ul_logit, ul_adt_logit)  #*4.0
            ent_loss = L.entropy_y_x(ul_logit)
            loss = nll_loss + lamb * additional_loss + ent_loss

        elif FLAGS.method == 'VAT+VAdD-QE':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=True)
            ul_adt_logit = adt.forward_adv_drop(ul_x_2,
                                                ul_logit,
                                                FLAGS.delta,
                                                is_training=True,
                                                update_batch_stats=True,
                                                mode='VAdD-QE')
            additional_loss = L.qe_loss(ul_logit, ul_adt_logit)  #*4.0
            ent_loss = L.entropy_y_x(ul_logit)
            vat_loss = adt.virtual_adversarial_loss(ul_x_1, ul_logit)
            loss = nll_loss + lamb * additional_loss + vat_loss + ent_loss

        elif FLAGS.method == 'VAT':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=False)
            ent_loss = L.entropy_y_x(ul_logit)  # + L.entropy_y_x(ul_adt_logit)
            vat_loss = adt.virtual_adversarial_loss(ul_x_1, ul_logit)
            loss = nll_loss + vat_loss + ent_loss

        elif FLAGS.method == 'Pi':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=True)
            ul_adt_logit = adt.forward(ul_x_2, update_batch_stats=True)
            additional_loss = L.qe_loss(ul_logit, ul_adt_logit)  #*4.0
            ent_loss = L.entropy_y_x(ul_logit)
            loss = nll_loss + lamb * additional_loss + ent_loss

        elif FLAGS.method == 'baseline':
            # NOTE(review): this branch uses vat.forward while every other
            # branch uses adt.forward — confirm `vat` is the intended module.
            logit = vat.forward(x_1)
            nll_loss = L.ce_loss(logit, y)
            scope = tf.get_variable_scope()
            scope.reuse_variables()
            additional_loss = 0
            # BUG FIX: the original never assigned `loss` in this branch,
            # so compute_gradients below raised NameError for 'baseline'.
            loss = nll_loss + additional_loss
        else:
            raise NotImplementedError

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom, beta2=0.999)
        tvars = tf.trainable_variables()
        grads_and_vars = opt.compute_gradients(loss, tvars)
        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    return loss, train_op, global_step
Example #10
0
def build_training_graph(x_1, x_2, y, lr, mom, lamb):
    """Build the supervised/regularized training graph for the method
    selected by ``FLAGS.method``.

    Combines the cross-entropy loss on (x_1, y) with a consistency
    regularizer computed on the augmented view x_2, weighted by ``lamb``.
    Returns ``(loss, train_op, global_step)``.
    """
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )
    logit = adt.forward(x_1)
    nll_loss = L.ce_loss(logit, y)

    # Reuse the variables created by the forward pass above.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):

        if FLAGS.method == 'SAdD':
            adt_logit = adt.forward_adv_drop(x_2,
                                             y,
                                             FLAGS.delta,
                                             is_training=True)
            additional_loss = L.ce_loss(adt_logit, y)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'VAdD-KL':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2,
                                             logit_p,
                                             FLAGS.delta,
                                             is_training=True,
                                             mode=FLAGS.method)
            additional_loss = L.kl_divergence_with_logit(logit_p, adt_logit)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'VAdD-QE':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2,
                                             logit_p,
                                             FLAGS.delta,
                                             is_training=True,
                                             mode=FLAGS.method)
            additional_loss = L.qe_loss(adt_logit, logit_p)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'VAT+VAdD-KL':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2,
                                             logit_p,
                                             FLAGS.delta,
                                             is_training=True,
                                             mode='VAdD-KL')
            additional_loss = L.kl_divergence_with_logit(logit_p, adt_logit)
            vat_loss = adt.virtual_adversarial_loss(x_1, logit_p)
            loss = nll_loss + lamb * additional_loss + vat_loss
        elif FLAGS.method == 'VAT+VAdD-QE':
            logit_p = logit
            adt_logit = adt.forward_adv_drop(x_2,
                                             logit_p,
                                             FLAGS.delta,
                                             is_training=True,
                                             mode='VAdD-QE')
            additional_loss = L.qe_loss(adt_logit, logit_p)
            vat_loss = adt.virtual_adversarial_loss(x_1, logit_p)
            loss = nll_loss + lamb * additional_loss + vat_loss
        elif FLAGS.method == 'VAT':
            # Dead-code removal: the original assigned
            # `logit_p = tf.stop_gradient(logit)` and immediately overwrote
            # it with `logit_p = logit`. The stop_gradient happens inside
            # virtual_adversarial_loss anyway, so passing the raw logits
            # matches the other branches and the sibling graph builder.
            logit_p = logit
            vat_loss = adt.virtual_adversarial_loss(x_1, logit_p)
            loss = nll_loss + vat_loss
        elif FLAGS.method == 'Pi':
            adt_logit = adt.forward(x_2)
            additional_loss = L.qe_loss(adt_logit, logit)
            loss = nll_loss + lamb * additional_loss
        elif FLAGS.method == 'baseline':
            # BUG FIX: the original assigned `adt_masks = masks` (NameError:
            # `masks` is undefined and `adt_masks` was never used) and never
            # assigned `loss`, so compute_gradients below raised NameError.
            additional_loss = 0
            loss = nll_loss + additional_loss
        else:
            raise NotImplementedError

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom)
        tvars = tf.trainable_variables()
        grads_and_vars = opt.compute_gradients(loss, tvars)
        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    return loss, train_op, global_step