def build_training_graph(x, y, ul_x, lr, mom):
    """Build the training graph and return ``(loss, train_op, global_step)``.

    The total loss is the cross-entropy on the labelled batch plus a
    regulariser selected by ``FLAGS.method``:

    * ``'vat'``      -- virtual adversarial loss on ``ul_x``
    * ``'vatent'``   -- virtual adversarial loss plus ``L.entropy_y_x``
    * ``'baseline'`` -- no additional term

    Args:
        x: labelled input batch passed to ``vat.forward``.
        y: targets passed to ``L.ce_loss``.
        ul_x: unlabelled input batch used for the regulariser.
        lr: learning rate handed to ``tf.train.AdamOptimizer``.
        mom: value used as Adam's ``beta1``.

    Raises:
        NotImplementedError: if ``FLAGS.method`` is none of the above.
    """
    # The step counter is created as a non-trainable float32 scalar.
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )

    # Supervised part.
    logit = vat.forward(x)
    nll_loss = L.ce_loss(logit, y)

    method = FLAGS.method
    # Second forward pass goes through a reuse=True scope so that it shares
    # the variables created by the first one.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        if method in ('vat', 'vatent'):
            ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
            additional_loss = vat.virtual_adversarial_loss(ul_x, ul_logit)
            if method == 'vatent':
                additional_loss = additional_loss + L.entropy_y_x(ul_logit)
        elif method == 'baseline':
            additional_loss = 0
        else:
            raise NotImplementedError
        loss = nll_loss + additional_loss

    adam = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom)
    trainable = tf.trainable_variables()
    gradients = adam.compute_gradients(loss, trainable)
    train_op = adam.apply_gradients(gradients, global_step=global_step)
    return loss, train_op, global_step
def build_training_graph(x, y, ul_x, lr, mom):
    """Build the classifier loss terms for the configured training method.

    Computes the cross-entropy on the labelled batch and, depending on
    ``FLAGS.method`` ('vat', 'vatent', 'lvat' or 'baseline'), an additional
    regularisation loss on ``ul_x``. For 'lvat' the perturbation is built
    from the output of an auto-encoder's encoder (``FLAGS.ae_type`` in
    'VAE', 'AE', 'Glow') and mapped back through its decoder.

    NOTE(review): the visible body ends right after collecting the classifier
    variables -- no combined loss, train op or ``return`` is present in this
    view. The definition looks truncated; confirm against the full file.

    Args:
        x: labelled input batch passed to ``vat.forward``.
        y: targets passed to ``L.ce_loss``.
        ul_x: unlabelled input batch.
        lr: learning rate for the Adam optimizer.
        mom: value used as Adam's ``beta1``.

    Raises:
        NotImplementedError: if ``FLAGS.method`` is not recognised.
    """
    logit = vat.forward(x)
    nll_loss = L.ce_loss(logit, y)
    # Placeholder value so x_reconst is bound for methods without a decoder.
    x_reconst = tf.constant(0)
    if FLAGS.method == 'vat':
        ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
        vat_loss, r_adv = vat.virtual_adversarial_loss(ul_x, ul_logit)
        x_adv = ul_x + r_adv
        additional_loss = vat_loss
    elif FLAGS.method == 'vatent':
        ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
        vat_loss, r_adv = vat.virtual_adversarial_loss(ul_x, ul_logit)
        x_adv = ul_x + r_adv
        ent_loss = L.entropy_y_x(ul_logit)
        additional_loss = vat_loss + ent_loss
    elif FLAGS.method == 'lvat':
        ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
        m_ae = get_ae()
        # Encode the unlabelled batch; the tuple layout returned by the
        # encoder differs per auto-encoder type.
        with tf.variable_scope(SCOPE_ENCODER ):
            if FLAGS.ae_type == 'VAE':
                _,z,_ = m_ae.encoder(ul_x, is_train=False)
            elif FLAGS.ae_type == 'AE':
                z = m_ae.encoder(ul_x, is_train=False)
            elif FLAGS.ae_type == 'Glow':
                print('[DEBUG] ... building Glow encoder')
                with tf.variable_scope('encoder' ):
                    # NOTE(review): this rebinds the parameter `y` (the
                    # labels) to an encoder output; any later use of `y` as
                    # labels would silently see this value instead.
                    y, logdet, z = m_ae.encoder(ul_x)
        decoder = m_ae.decoder
        if FLAGS.ae_type == 'Glow':
            print('[DEBUG] ... building Glow VAT loss function')
            vat_loss, r_adv_y, r_adv_z = vat.virtual_adversarial_loss_glow((y, logdet, z), ul_logit, decoder)
            print('[DEBUG] ... building Glow decoder')
            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                #with tf.variable_scope('decoder' ):
                x_adv = decoder((y+r_adv_y, logdet, z+r_adv_z))
                x_reconst = decoder((y, logdet, z))
        else:
            vat_loss, r_adv = vat.virtual_adversarial_loss(z, ul_logit, decoder)
            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                x_adv = decoder(z + r_adv, False)
                x_reconst = decoder(z, False)
        additional_loss = vat_loss
    elif FLAGS.method == 'baseline':
        # NOTE(review): x_adv is never bound on this path.
        additional_loss = 0
    else:
        raise NotImplementedError
    optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom)
    # Only variables under the classifier scope are collected here.
    theta_classifier = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=SCOPE_CLASSIFIER)
def test(epoch):
    """Evaluate ``net`` on ``val_loader`` and checkpoint it when accuracy improves.

    For every validation batch the cross-entropy loss (plus the VAT loss when
    ``args.training == 'vat'``) is accumulated, running accuracy is shown via
    ``progress_bar`` and the running totals are sent to ``logger``. After the
    loop, if the accuracy exceeds the module-level ``best_acc``, the model is
    saved under ``./checkpoint`` and ``best_acc`` is updated.

    Args:
        epoch: epoch number, stored in the checkpoint.

    Raises:
        NotImplementedError: if ``args.training`` is neither 'supervised' nor
            'vat' (previously this surfaced as an unbound/stale
            ``additional_loss``).
    """
    global best_acc
    net.eval()
    criterion = nn.CrossEntropyLoss()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # No optimizer step is taken here; gradients are only cleared.
        optimizer.zero_grad()
        outputs = net.forward(inputs)
        dll_loss = criterion(outputs, targets)

        if args.training == 'supervised':
            additional_loss = 0
        elif args.training == 'vat':
            additional_loss = vat.virtual_adversarial_loss(inputs, outputs, use_gpu=use_cuda)
        else:
            raise NotImplementedError(
                'Unknown training mode: %r' % (args.training,))

        loss = dll_loss + additional_loss
        test_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # .item() keeps `correct` a plain Python int, so the accuracy below is
        # a true float and the checkpoint does not store a tensor.
        correct += predicted.eq(targets.data).sum().item()
        test_acc = 100 * float(correct) / total

        progress_bar(
            batch_idx,
            len(val_loader) + 1,
            'Loss: %.5f | Acc: %.5f%% (%d/%d)'
            % (test_loss / (batch_idx + 1), test_acc, correct, total))

        # Tensorboard logging
        # NOTE(review): logged once per batch with the cumulative (not mean)
        # loss; placement reconstructed from whitespace-collapsed source --
        # confirm it was not meant to run once after the loop.
        info = {'test_loss': test_loss, 'test_accuracy': test_acc}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, batch_idx + 1)

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(
            state,
            './checkpoint/ckpt_{0}_{1}.t7'.format(args.arch, args.numlabels))
        best_acc = acc
def build_eval_graph(x, y, ul_x):
    """Build the evaluation graph and return a dict of metric tensors.

    Returned keys, in order: ``'NLL'`` (cross-entropy), ``'Acc'`` (accuracy),
    ``'AT_loss'`` (adversarial loss on the labelled batch) and ``'VAT_loss'``
    (virtual adversarial loss on ``ul_x``).

    Args:
        x: labelled input batch.
        y: targets for ``x``.
        ul_x: unlabelled input batch.
    """
    logit = vat.forward(x, is_training=False, update_batch_stats=False)
    nll_loss = L.ce_loss(logit, y)
    acc = L.accuracy(logit, y)

    # Everything built from here on shares the variables created above.
    tf.get_variable_scope().reuse_variables()

    # NOTE(review): is_training=True here, while every other call in this
    # evaluation graph passes is_training=False -- confirm this is intended.
    at_loss = vat.adversarial_loss(x, y, nll_loss, is_training=True)

    ul_logit = vat.forward(ul_x, is_training=False, update_batch_stats=False)
    vat_loss = vat.virtual_adversarial_loss(ul_x, ul_logit, is_training=False)

    return {
        'NLL': nll_loss,
        'Acc': acc,
        'AT_loss': at_loss,
        'VAT_loss': vat_loss,
    }
def train(epoch):
    """Run one training pass over ``train_loader``.

    Each batch is optimised on the cross-entropy loss, plus the VAT loss when
    ``args.training == 'vat'``. Running loss/accuracy are shown through
    ``progress_bar`` and sent to ``logger``.

    Args:
        epoch: epoch number, used only for the console banner.

    Raises:
        NotImplementedError: if ``args.training`` is neither 'supervised' nor
            'vat' (previously this surfaced as an unbound/stale
            ``additional_loss``).
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    criterion = nn.CrossEntropyLoss()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        outputs = net.forward(inputs)
        dll_loss = criterion(outputs, targets)

        if args.training == 'supervised':
            additional_loss = 0
        elif args.training == 'vat':
            additional_loss = vat.virtual_adversarial_loss(inputs, outputs, use_gpu=use_cuda)
        else:
            raise NotImplementedError(
                'Unknown training mode: %r' % (args.training,))

        loss = dll_loss + additional_loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # .item() keeps `correct` a plain Python int instead of a tensor that
        # grows a new allocation every batch.
        correct += predicted.eq(targets.data).sum().item()
        train_acc = 100 * float(correct) / total

        progress_bar(
            batch_idx,
            len(train_loader) + 1,
            'Loss: %.5f | Acc: %.5f%% (%d/%d)'
            % (train_loss / (batch_idx + 1), train_acc, correct, total))

        # NOTE(review): logged once per batch with the cumulative (not mean)
        # loss; placement reconstructed from whitespace-collapsed source --
        # confirm it was not meant to run once after the loop.
        info = {'train_loss': train_loss, 'train_accuracy': train_acc}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, batch_idx + 1)
def build_eval_graph(x, y, ul_x):
    """Build the evaluation graph; return ``(losses, results)`` dicts.

    ``losses`` always holds ``'NLL'`` and ``'Acc'`` for the labelled batch and,
    depending on ``FLAGS.method``, ``'VAT_loss'`` ('vat'/'vatent') or
    ``'LVAT_loss'`` ('lvat').

    ``results`` holds ``'x'`` (the unlabelled batch), ``'x_reconst'``,
    ``'y_reconst'``, ``'x_adv'``, ``'y_pred'``, ``'y_true'`` and the per-sample
    L2 norms ``'x_diff'`` / ``'x_diff_adv'`` between the unlabelled batch and
    its reconstruction / adversarial version. For 'vat'/'vatent' there is no
    decoder, so ``x_reconst`` is ``ul_x`` itself and ``y_reconst`` is the
    argmax of the logits on ``ul_x``.

    Args:
        x: labelled input batch.
        y: targets for ``x``; ``tf.argmax(y, 1)`` is taken as the true class.
        ul_x: unlabelled input batch the perturbations are built for.

    Raises:
        NotImplementedError: if ``FLAGS.method`` (or, for 'lvat',
            ``FLAGS.ae_type``) is not handled here. Previously such values
            fell through and failed later with a ``NameError``.
    """
    losses = {}
    logit = vat.forward(x, is_training=False, update_batch_stats=False)
    nll_loss = L.ce_loss(logit, y)
    losses['NLL'] = nll_loss
    acc = L.accuracy(logit, y)
    losses['Acc'] = acc

    # Everything built from here on shares the variables created above.
    scope = tf.get_variable_scope()
    scope.reuse_variables()

    results = {}
    if FLAGS.method == 'vat' or FLAGS.method == 'vatent':
        ul_logit = vat.forward(ul_x, is_training=False, update_batch_stats=False)
        vat_loss, r_adv = vat.virtual_adversarial_loss(ul_x, ul_logit, is_training=False)
        losses['VAT_loss'] = vat_loss
        x_adv = ul_x + r_adv
        # No auto-encoder in this mode: stand-ins keep `results` uniform.
        x_reconst = ul_x
        y_reconst = tf.argmax(ul_logit, 1)
    elif FLAGS.method == 'lvat':
        ul_logit = vat.forward(ul_x, is_training=False, update_batch_stats=False)
        m_ae = get_ae()
        decoder = m_ae.decoder
        if FLAGS.ae_type == 'Glow':
            print('[DEBUG] ... building Glow encoder in eval graph')
            with tf.variable_scope(SCOPE_ENCODER, reuse=tf.AUTO_REUSE):
                with tf.variable_scope('encoder'):
                    y_latent, logdet, z = m_ae.encoder(ul_x)
            lvat_loss, r_adv_y, r_adv_z = vat.virtual_adversarial_loss_glow(
                (y_latent, logdet, z), ul_logit, decoder)
            print('[DEBUG] ... building Glow decoder in eval graph')
            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                with tf.variable_scope('decoder'):
                    x_adv = decoder((y_latent + r_adv_y, logdet, z + r_adv_z))
                    x_reconst = decoder((y_latent, logdet, z))
        elif FLAGS.ae_type in ('VAE', 'AE'):
            with tf.variable_scope(SCOPE_ENCODER, reuse=tf.AUTO_REUSE):
                if FLAGS.ae_type == 'VAE':
                    _, z, _ = m_ae.encoder(ul_x, is_train=False)
                else:
                    z = m_ae.encoder(ul_x, is_train=False)
            lvat_loss, r_adv = vat.virtual_adversarial_loss(z, ul_logit, decoder)
            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                x_adv = decoder(z + r_adv, False)
                x_reconst = decoder(z, False)
        else:
            raise NotImplementedError(
                'Unsupported ae_type for eval graph: %r' % (FLAGS.ae_type,))
        losses['LVAT_loss'] = lvat_loss
        # Classify the reconstruction to see whether it keeps its label.
        logit_reconst = vat.forward(x_reconst, is_training=False, update_batch_stats=False)
        y_reconst = tf.argmax(logit_reconst, 1)
    else:
        raise NotImplementedError(
            'Unsupported method for eval graph: %r' % (FLAGS.method,))

    results['x'] = ul_x
    results['x_reconst'] = x_reconst
    results['y_reconst'] = y_reconst
    results['x_adv'] = x_adv
    results['y_pred'] = tf.argmax(logit, 1)
    results['y_true'] = tf.argmax(y, 1)

    # Distances are measured against ul_x: x_reconst and x_adv are both
    # derived from it (and it is what results['x'] reports). The original
    # compared against the labelled batch `x`, which only matches when the
    # caller passes the same tensor for both.
    flat_shape = (-1, FLAGS.img_size * FLAGS.img_size * 3)
    ul_x_flat = tf.reshape(ul_x, flat_shape)
    x_adv_flat = tf.reshape(x_adv, flat_shape)
    x_reconst_flat = tf.reshape(x_reconst, flat_shape)
    results['x_diff'] = tf.norm(ul_x_flat - x_reconst_flat, axis=1)
    results['x_diff_adv'] = tf.norm(ul_x_flat - x_adv_flat, axis=1)
    return losses, results
def train(input_t, output_map, alpha, max_it, root, batch_size, is_training, id, use_vat, use_pseudo_labels,
          use_mean_teacher, dataset):
    """
    Build the loss and optimiser around an existing network and run training.

    The base loss is a masked mean-squared error between ``y`` and the network
    output. With ``use_mean_teacher`` a second (teacher) ``unet`` is built
    whose variables are read through an exponential moving average, and a
    consistency term weighted by a warm-up schedule is added. With ``use_vat``
    the loss is rescaled by the mask and a virtual adversarial term is added.
    A checkpoint is written to ``tmp<iteration>/model.ckpt`` every
    ``test_interval`` iterations.

    :param input_t: input tensor
    :param output_map: output layer of the network
    :param alpha: placeholder for leaky relu
    :param max_it: maximum training iterations
    :param root: base directory that contains the images
    :param batch_size: batch size
    :param is_training: toggle training
    :param id: GPU id (not used inside this function)
    :param use_vat: Enable VAT
    :param use_pseudo_labels: Use pseudo labels
    :param use_mean_teacher: Use mean teacher
    :param dataset: Choose dataset ("ENDOVIS" selects 256x320 with 5 parts and
        4 connections; anything else selects 288x384 with 4 parts)
    :return: None
    """
    if dataset == "ENDOVIS":
        h, w, num_parts, num_connections = 256, 320, 5, 4
    else:
        h, w, num_parts, num_connections = 288, 384, 4, 0

    # GPU Config
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=.95)

    # Set up placeholders
    y = tf.placeholder(tf.float32, shape=[None, h, w, num_parts + num_connections])
    lr = tf.placeholder(tf.float32)
    loss_mask = tf.placeholder(tf.float32, shape=[batch_size])

    # Only created in mean-teacher mode; stay None otherwise so the training
    # loop knows not to feed / run them.
    m = None
    ema_update = None

    # Loss
    if not use_mean_teacher:
        avr_loss = tf.losses.mean_squared_error(y, output_map,
                                                weights=tf.reshape(loss_mask, [batch_size, 1, 1, 1]))
    else:
        ema = tf.train.ExponentialMovingAverage(decay=.95)

        def ema_getter(getter, name, *args, **kwargs):
            # Serve the averaged copy of a variable when one exists.
            var = getter(name, *args, **kwargs)
            ema_var = ema.average(var)
            return ema_var if ema_var is not None else var

        tf.get_variable_scope().set_custom_getter(ema_getter)
        model_vars = tf.trainable_variables()
        output_student = output_map
        ema_update = ema.apply(model_vars)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_update)
        output_teacher, _ = unet(input_t, .9 if dataset == "RMIT" else .7, 3, num_parts + num_connections,
                                 is_training=is_training, features_root=64, alpha=alpha)
        output_teacher = tf.stop_gradient(output_teacher)
        avr_loss = batch_size / tf.reduce_sum(loss_mask) * \
            tf.losses.mean_squared_error(y, output_student,
                                         weights=tf.reshape(loss_mask, [batch_size, 1, 1, 1]))
        # Weight of the consistency term, fed per step from sigmoid_schedule.
        m = tf.placeholder(tf.float32, shape=[])
        avr_loss = avr_loss + m * .1 * tf.losses.mean_squared_error(output_teacher, output_student)

    if use_vat:
        avr_loss = batch_size / tf.reduce_sum(loss_mask) * avr_loss + \
            virtual_adversarial_loss(input_t, y, is_training=is_training, alpha=alpha)

    # Adam solver
    with tf.variable_scope("Adam", reuse=tf.AUTO_REUSE):
        opt = tf.train.AdamOptimizer(lr).minimize(avr_loss)
    if ema_update is not None:
        # minimize() does not run UPDATE_OPS on its own, and nothing in this
        # function fetches them; run the EMA update together with the
        # optimiser step so the teacher weights actually track the student.
        opt = tf.group(opt, ema_update)

    # Start session and initialize weights
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True,
                                            log_device_placement=True))
    writer = None
    try:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(max_to_keep=10000)

        # NOTE(review): `use_tvm` is not a parameter of this function and is
        # not defined in the visible code; it is only evaluated when both
        # use_vat and use_mean_teacher are falsy. Confirm it exists at module
        # level.
        # The batch source follows `dataset` so that it matches the
        # dataset-dependent placeholder shapes above.
        b_train = Batch(root, batch_size, dataset=dataset,
                        include_unlabelled=use_vat or use_mean_teacher or use_tvm,
                        pseudo_label=use_pseudo_labels)
        b_test = Batch(root, batch_size, dataset=dataset, include_unlabelled=False, testing=True,
                       augment=False, train_postprocessing=False)

        current_lr = 1e-3
        print("Chosen lr:", current_lr)

        # if model_dir is not None:
        #     restore_op, restore_dict = tf.contrib.framework.assign_from_checkpoint(
        #         model_dir + "/model.ckpt",
        #         tf.contrib.slim.get_variables_to_restore(),
        #         ignore_missing_vars=True
        #     )
        #     sess.run(restore_op, feed_dict=restore_dict)
        #     print("Restored session")

        # save graph
        writer = tf.summary.FileWriter(logdir='logdir', graph=sess.graph)
        writer.flush()

        test_interval = 250 if use_vat else 200

        def sigmoid_schedule(global_step, warm_up_steps=20000):
            # Ramps from exp(-5) at step 0 up to 1 at the end of warm-up.
            if global_step > warm_up_steps:
                return 1.
            return np.exp(-5. * (1. - (global_step / warm_up_steps)) ** 2)

        for i in range(max_it):
            imgs, targets, _, mask = b_train.get_batch()

            feed = {
                input_t: imgs,
                y: targets,
                lr: current_lr,
                is_training: True,
                # A fresh leaky-relu slope is sampled for every step.
                alpha: 1 / np.random.uniform(low=3, high=8),
                loss_mask: mask,
            }
            if m is not None:
                feed[m] = sigmoid_schedule(i)

            current_loss, net_out, _ = sess.run([avr_loss, output_map, opt], feed_dict=feed)

            if i % 100 == 0:
                print("Current regression loss:", current_loss.sum())

                loc_pred = []
                loc_true = []
                # Peak locations are only reported when the first sample of
                # the batch has exactly one instrument.
                if b_train.batch_instrument_count[0] == 1:
                    for ch in range(num_parts):
                        _, _, _, m_loc1 = cv2.minMaxLoc(net_out[0, :, :, ch])
                        loc_pred.append(m_loc1)
                        _, _, _, m_loc2 = cv2.minMaxLoc(targets[0][:, :, ch])
                        loc_true.append(m_loc2)

                print("For the first sample->  Predicted: {} Ground Truth: {}\n".format(loc_pred, loc_true))

            # save model for evaluation
            if i % test_interval == 0 and i != 0:
                print("Testing at iteration", i, "...")
                dir2save = os.path.join("tmp" + str(i), "model.ckpt")
                save_path = saver.save(sess, dir2save)
                print("Saved model to", save_path)
    finally:
        # Release the session and the summary writer even if training fails.
        if writer is not None:
            writer.close()
        sess.close()