def build_training_graph(x, y, ul_x, lr, mom):
    """Builds the supervised + VAT training graph; returns (loss, train_op, global_step)."""
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )
    # Supervised cross-entropy loss on the labeled batch.
    logit = vat.forward(x)
    nll_loss = L.ce_loss(logit, y)
    # Reuse the same model parameters when computing the unlabeled regularizer.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        if FLAGS.method == 'vat':
            ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
            vat_loss = vat.virtual_adversarial_loss(ul_x, ul_logit)
            additional_loss = vat_loss
        elif FLAGS.method == 'vatent':
            # VAT plus conditional entropy minimization on the unlabeled predictions.
            ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
            vat_loss = vat.virtual_adversarial_loss(ul_x, ul_logit)
            ent_loss = L.entropy_y_x(ul_logit)
            additional_loss = vat_loss + ent_loss
        elif FLAGS.method == 'baseline':
            # Supervised-only baseline: no regularization on the unlabeled data.
            additional_loss = 0
        else:
            raise NotImplementedError
        loss = nll_loss + additional_loss
    opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom)
    tvars = tf.trainable_variables()
    grads_and_vars = opt.compute_gradients(loss, tvars)
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    return loss, train_op, global_step
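# ---------------------------------------------------------------------------
# Reference note (hedged summary, not part of the original file): the 'vat'
# branch above adds the virtual adversarial loss of Miyato et al.,
#
#     LDS(x) = KL( p(y | x; theta_hat) || p(y | x + r_vadv; theta) ),
#     r_vadv = argmax_{ ||r||_2 <= eps } KL( p(y | x; theta_hat) || p(y | x + r; theta) ),
#
# computed on the unlabeled batch, while 'vatent' additionally minimizes the
# conditional entropy of the unlabeled predictions (L.entropy_y_x above).
# ---------------------------------------------------------------------------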
def build_training_graph(x_1, x_2, y, ul_x_1, ul_x_2, lr, mom, lamb):
    """Builds the supervised + VAdD/VAT/Pi training graph; returns (loss, train_op, global_step)."""
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )
    # Supervised cross-entropy loss on the first labeled view.
    logit = adt.forward(x_1, update_batch_stats=True)
    nll_loss = L.ce_loss(logit, y)
    # Reuse the same model parameters when computing the unlabeled regularizer.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        if FLAGS.method == 'VAdD':
            # Virtual adversarial dropout on the second unlabeled view, plus entropy minimization.
            ul_logit = adt.forward(ul_x_1)
            ent_loss = L.entropy_y_x(ul_logit)
            vadt_loss = adt.virtual_adversarial_dropout_loss(ul_x_2, ul_logit)
            loss = nll_loss + lamb * vadt_loss + ent_loss
        elif FLAGS.method == 'VAT':
            ul_logit = adt.forward(ul_x_1, update_batch_stats=False)
            ent_loss = L.entropy_y_x(ul_logit)
            vat_loss = adt.virtual_adversarial_loss(ul_x_1, ul_logit)
            loss = nll_loss + vat_loss + ent_loss
        elif FLAGS.method == 'Pi':
            # Pi-model consistency: squared error between logits of two stochastic passes.
            ul_logit = adt.forward(ul_x_1, update_batch_stats=True)
            ul_adt_logit = adt.forward(ul_x_2, update_batch_stats=True)
            additional_loss = L.qe_loss(ul_logit, ul_adt_logit)  # *4.0
            ent_loss = L.entropy_y_x(ul_logit)
            loss = nll_loss + lamb * additional_loss + ent_loss
        else:
            raise NotImplementedError
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom, beta2=0.999)
        tvars = tf.trainable_variables()
        grads_and_vars = opt.compute_gradients(loss, tvars)
        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    return loss, train_op, global_step
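# ---------------------------------------------------------------------------
# Hedged usage sketch (assumed, not part of the original code): one way the
# VAdD-style build_training_graph above could be driven in a TF1 session.
# Everything below is illustrative: the CIFAR-10-like input shapes (32x32x3,
# 10 classes), batch sizes, hyperparameter values, and the random dummy
# batches are assumptions. `tensorflow as tf` and the adt/L modules are
# assumed to be imported at the top of the file, and FLAGS.method is assumed
# to be defined elsewhere as one of 'VAdD', 'VAT', or 'Pi'. A real run would
# feed two differently augmented (or differently dropout-masked) views as
# x_1/x_2 and ul_x_1/ul_x_2 instead of the identical arrays used here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    x_1 = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    x_2 = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    y = tf.placeholder(tf.float32, shape=[None, 10])
    ul_x_1 = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    ul_x_2 = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])

    loss, train_op, global_step = build_training_graph(
        x_1, x_2, y, ul_x_1, ul_x_2, lr=0.001, mom=0.9, lamb=1.0)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(10):  # a few dummy steps; a real loop would iterate over the dataset
            xb = np.random.rand(32, 32, 32, 3).astype(np.float32)
            yb = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=32)]
            ub = np.random.rand(128, 32, 32, 3).astype(np.float32)
            _, loss_val, step = sess.run(
                [train_op, loss, global_step],
                feed_dict={x_1: xb, x_2: xb, y: yb, ul_x_1: ub, ul_x_2: ub})
            print('step %d  loss %.4f' % (step, loss_val))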