Code example #1
File: train_semisup.py Project: hanwgyu/vat_tf
def build_training_graph(x, y, ul_x, lr, mom):
    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False,
    )
    logit = vat.forward(x)
    nll_loss = L.ce_loss(logit, y)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        if FLAGS.method == 'vat':
            ul_logit = vat.forward(ul_x,
                                   is_training=True,
                                   update_batch_stats=False)
            vat_loss = vat.virtual_adversarial_loss(ul_x, ul_logit)
            additional_loss = vat_loss
        elif FLAGS.method == 'vatent':
            ul_logit = vat.forward(ul_x,
                                   is_training=True,
                                   update_batch_stats=False)
            vat_loss = vat.virtual_adversarial_loss(ul_x, ul_logit)
            ent_loss = L.entropy_y_x(ul_logit)
            additional_loss = vat_loss + ent_loss
        elif FLAGS.method == 'baseline':
            additional_loss = 0
        else:
            raise NotImplementedError
        loss = nll_loss + additional_loss

    opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom)
    tvars = tf.trainable_variables()
    grads_and_vars = opt.compute_gradients(loss, tvars)
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    return loss, train_op, global_step
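Example #1 is a plain TF1 graph builder: the caller still has to create the placeholders and drive a session. A minimal usage sketch, assuming hypothetical placeholder shapes, a next_batch() data source, and a num_steps budget (none of which appear in the source):

# Hypothetical wiring for build_training_graph (TF1 graph mode).
x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])      # labelled images (shape assumed)
y = tf.placeholder(tf.float32, shape=[None, 10])             # one-hot labels (shape assumed)
ul_x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])   # unlabelled images
lr = tf.placeholder(tf.float32, shape=[])
mom = tf.placeholder(tf.float32, shape=[])

loss, train_op, global_step = build_training_graph(x, y, ul_x, lr, mom)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):                  # num_steps: assumed budget
        batch_x, batch_y, batch_ul = next_batch()  # hypothetical data source
        _, cur_loss = sess.run(
            [train_op, loss],
            feed_dict={x: batch_x, y: batch_y, ul_x: batch_ul,
                       lr: 1e-3, mom: 0.9})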
Code example #2
File: train_semisup.py Project: geosada/LVAT
def build_training_graph(x, y, ul_x, lr, mom):

    logit = vat.forward(x)

    nll_loss = L.ce_loss(logit, y)
    x_reconst = tf.constant(0)  # dummy; overwritten in the 'lvat' branch
    if FLAGS.method == 'vat':
        ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
        vat_loss, r_adv = vat.virtual_adversarial_loss(ul_x, ul_logit)
        x_adv = ul_x + r_adv
        additional_loss = vat_loss

    elif FLAGS.method == 'vatent':
        ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
        vat_loss, r_adv = vat.virtual_adversarial_loss(ul_x, ul_logit)
        x_adv = ul_x + r_adv
        ent_loss = L.entropy_y_x(ul_logit)
        additional_loss = vat_loss + ent_loss

    elif FLAGS.method == 'lvat':
        ul_logit = vat.forward(ul_x, is_training=True, update_batch_stats=False)
        
        m_ae = get_ae()
        with tf.variable_scope(SCOPE_ENCODER):
            if FLAGS.ae_type == 'VAE':
                _, z, _ = m_ae.encoder(ul_x, is_train=False)
            elif FLAGS.ae_type == 'AE':
                z = m_ae.encoder(ul_x, is_train=False)
            elif FLAGS.ae_type == 'Glow':
                print('[DEBUG] ... building Glow encoder')
                with tf.variable_scope('encoder'):
                    # Named y_latent so the label tensor y is not shadowed
                    # (the eval graph in example #6 uses the same name).
                    y_latent, logdet, z = m_ae.encoder(ul_x)

        decoder = m_ae.decoder
        if FLAGS.ae_type == 'Glow':
            print('[DEBUG] ... building Glow VAT loss function')
            vat_loss, r_adv_y, r_adv_z = vat.virtual_adversarial_loss_glow((y_latent, logdet, z), ul_logit, decoder)

            print('[DEBUG] ... building Glow decoder')
            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                x_adv     = decoder((y_latent + r_adv_y, logdet, z + r_adv_z))
                x_reconst = decoder((y_latent,           logdet, z))

        else:
            vat_loss, r_adv = vat.virtual_adversarial_loss(z, ul_logit, decoder)

            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                x_adv     = decoder(z + r_adv, False)
                x_reconst = decoder(z, False)

        additional_loss = vat_loss

    elif FLAGS.method == 'baseline':
        additional_loss = 0
    else:
        raise NotImplementedError

    optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=mom) 
    theta_classifier = tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, scope=SCOPE_CLASSIFIER)
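The listing breaks off right after the classifier variables are collected. A minimal sketch of the likely remainder, assuming the compute/apply-gradients pattern of example #1; restricting var_list to the SCOPE_CLASSIFIER variables keeps the pretrained autoencoder frozen:

    # --- hypothetical completion of the truncated snippet above ---
    loss = nll_loss + additional_loss
    grads_and_vars = optimizer.compute_gradients(loss, var_list=theta_classifier)
    train_op = optimizer.apply_gradients(grads_and_vars)
    return loss, train_op  # return values assumed, mirroring example #1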
Code example #3
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()

        outputs = net.forward(inputs)

        dll_loss = nn.CrossEntropyLoss()(outputs, targets)

        if args.training == 'supervised':
            additional_loss = 0
        elif args.training == 'vat':
            vat_loss = vat.virtual_adversarial_loss(inputs,
                                                    outputs,
                                                    use_gpu=use_cuda)
            additional_loss = vat_loss

        loss = dll_loss + additional_loss

        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        test_acc = 100 * float(correct) / total

        progress_bar(
            batch_idx,
            len(val_loader) + 1, 'Loss: %.5f | Acc: %.5f%% (%d/%d)' %
            (test_loss /
             (batch_idx + 1), 100 * float(correct) / total, correct, total))

        # Tensorboard logging
        info = {'test_loss': test_loss, 'test_accuracy': test_acc}

        for tag, value in info.items():
            logger.scalar_summary(tag, value, batch_idx + 1)

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(
            state,
            './checkpoint/ckpt_{0}_{1}.t7'.format(args.arch, args.numlabels))
        best_acc = acc
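Since the checkpoint stores the module object itself rather than a state_dict, resuming needs only torch.load; a hypothetical sketch (the args values must match the ones used when saving):

# Hypothetical resume sketch for the checkpoint format saved above.
checkpoint = torch.load('./checkpoint/ckpt_{0}_{1}.t7'.format(args.arch, args.numlabels))
net = checkpoint['net']
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch'] + 1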
Code example #4
File: train_semisup.py Project: hanwgyu/vat_tf
def build_eval_graph(x, y, ul_x):
    losses = {}
    logit = vat.forward(x, is_training=False, update_batch_stats=False)
    nll_loss = L.ce_loss(logit, y)
    losses['NLL'] = nll_loss
    acc = L.accuracy(logit, y)
    losses['Acc'] = acc
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    at_loss = vat.adversarial_loss(x, y, nll_loss, is_training=True)
    losses['AT_loss'] = at_loss
    ul_logit = vat.forward(ul_x, is_training=False, update_batch_stats=False)
    vat_loss = vat.virtual_adversarial_loss(ul_x, ul_logit, is_training=False)
    losses['VAT_loss'] = vat_loss
    return losses
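build_eval_graph only returns loss tensors, so evaluation amounts to running the dictionary over the validation batches; a hypothetical averaging loop (sess, the placeholders, next_eval_batch and n_eval_batches are assumptions):

# Hypothetical averaging loop over the losses dict returned above.
losses = build_eval_graph(x, y, ul_x)
sums = {name: 0.0 for name in losses}
for _ in range(n_eval_batches):
    batch_x, batch_y, batch_ul = next_eval_batch()
    vals = sess.run(losses, feed_dict={x: batch_x, y: batch_y, ul_x: batch_ul})
    for name, value in vals.items():
        sums[name] += value
print({name: total / n_eval_batches for name, total in sums.items()})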
Code example #5
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        outputs = net.forward(inputs)

        dll_loss = nn.CrossEntropyLoss()(outputs, targets)

        if args.training == 'supervised':
            additional_loss = 0
        elif args.training == 'vat':
            vat_loss = vat.virtual_adversarial_loss(inputs,
                                                    outputs,
                                                    use_gpu=use_cuda)
            additional_loss = vat_loss

        loss = dll_loss + additional_loss

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        train_acc = 100 * float(correct) / total

        progress_bar(
            batch_idx,
            len(train_loader) + 1, 'Loss: %.5f | Acc: %.5f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100 * float(correct) / total, correct, total))

        info = {'train_loss': train_loss, 'train_accuracy': train_acc}

        for tag, value in info.items():
            logger.scalar_summary(tag, value, batch_idx + 1)
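train and test (example #3) are symmetric per-epoch passes over their loaders, so only a driver loop is missing; a minimal sketch (the epoch count is an assumption; start_epoch comes from the resume sketch after example #3, or 0 when starting fresh):

# Hypothetical driver combining train() with test() from example #3.
for epoch in range(start_epoch, start_epoch + 200):  # 200 epochs assumed
    train(epoch)
    test(epoch)  # updates best_acc and writes a checkpoint on improvement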
Code example #6
File: train_semisup.py Project: geosada/LVAT
def build_eval_graph(x, y, ul_x):
    losses = {}
    logit = vat.forward(x, is_training=False, update_batch_stats=False)
    nll_loss = L.ce_loss(logit, y)
    losses['NLL'] = nll_loss
    acc = L.accuracy(logit, y)
    losses['Acc'] = acc
    scope = tf.get_variable_scope()
    scope.reuse_variables()

    results = {}
    if FLAGS.method == 'vat' or FLAGS.method == 'vatent':
        ul_logit = vat.forward(ul_x, is_training=False, update_batch_stats=False)
        vat_loss, r_adv = vat.virtual_adversarial_loss(ul_x, ul_logit, is_training=False)
        losses['VAT_loss'] = vat_loss
        x_adv = ul_x + r_adv
        x_reconst = ul_x                    # dummy, kept for interface compatibility
        y_reconst = tf.argmax(ul_logit, 1)  # dummy, kept for interface compatibility

    elif FLAGS.method == 'lvat':
        ul_logit = vat.forward(ul_x, is_training=False, update_batch_stats=False)

        m_ae = get_ae()
        decoder = m_ae.decoder
        if FLAGS.ae_type == 'Glow':
            print('[DEBUG] ... building Glow encoder in eval graph')
            with tf.variable_scope(SCOPE_ENCODER, reuse=tf.AUTO_REUSE):
                with tf.variable_scope('encoder'):
                    y_latent, logdet, z = m_ae.encoder(ul_x)
            lvat_loss, r_adv_y, r_adv_z = vat.virtual_adversarial_loss_glow((y_latent, logdet, z), ul_logit, decoder)
            print('[DEBUG] ... building Glow decoder in eval graph')
            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                with tf.variable_scope('decoder'):
                    x_adv     = decoder((y_latent + r_adv_y, logdet, z + r_adv_z))
                    x_reconst = decoder((y_latent,           logdet, z))

        else:
            with tf.variable_scope(SCOPE_ENCODER, reuse=tf.AUTO_REUSE):
                if FLAGS.ae_type == 'VAE':
                    _, z, _ = m_ae.encoder(ul_x, is_train=False)
                elif FLAGS.ae_type == 'AE':
                    z = m_ae.encoder(ul_x, is_train=False)
            lvat_loss, r_adv = vat.virtual_adversarial_loss(z, ul_logit, decoder)
            with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE):
                x_adv     = decoder(z + r_adv, False)
                x_reconst = decoder(z, False)

        losses['LVAT_loss'] = lvat_loss

        logit_reconst = vat.forward(x_reconst, is_training=False, update_batch_stats=False)
        y_reconst = tf.argmax(logit_reconst, 1)

    else:
        # Any other method would leave x_adv/x_reconst/y_reconst undefined
        # below, so fail loudly, mirroring build_training_graph.
        raise NotImplementedError

    results['x']         = ul_x
    results['x_reconst'] = x_reconst
    results['y_reconst'] = y_reconst

    results['x_adv'] = x_adv
    results['y_pred'] = tf.argmax(logit, 1)
    results['y_true'] = tf.argmax(y, 1)

    x = tf.reshape(x, (-1, FLAGS.img_size*FLAGS.img_size*3))
    x_adv = tf.reshape(x_adv, (-1, FLAGS.img_size*FLAGS.img_size*3))
    x_reconst = tf.reshape(x_reconst, (-1, FLAGS.img_size*FLAGS.img_size*3))
    results['x_diff'] = tf.norm( x - x_reconst, axis=1)
    results['x_diff_adv'] = tf.norm( x - x_adv, axis=1)

    return losses, results
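Besides the scalar losses, this eval graph exposes the reconstructed and adversarially perturbed inputs; a hypothetical way to inspect them (sess and the batch feeds are assumptions):

# Hypothetical inspection of the results dict from build_eval_graph.
losses, results = build_eval_graph(x, y, ul_x)
out = sess.run(results, feed_dict={x: batch_x, y: batch_y, ul_x: batch_ul})
print('mean ||x - x_reconst||:', out['x_diff'].mean())
print('mean ||x - x_adv||:', out['x_diff_adv'].mean())
# out['x_adv'] holds the perturbed images, e.g. for visualization.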
Code example #7
def train(input_t, output_map, alpha, max_it, root, batch_size, is_training, id, use_vat, use_pseudo_labels,
          use_mean_teacher, dataset):
    """
    :param input_t: input tensor
    :param output_map: output layer of the network
    :param alpha: placeholder for leaky relu
    :param max_it: maximum training iterations
    :param root: base directory that contains the images
    :param batch_size: batch size
    :param is_training: toggle training
    :param id: GPU id
    :param use_vat: Enable VAT
    :param use_pseudo_labels: Use pseudo labels
    :param use_mean_teacher: Use mean teacher
    :param dataset: Choose dataset
    :return:
    """

    h = 256 if dataset == "ENDOVIS" else 288
    w = 320 if dataset == "ENDOVIS" else 384
    num_parts = 5 if dataset == "ENDOVIS" else 4
    num_connections = 4 if dataset == "ENDOVIS" else 0

    # GPU Config
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=.95)

    # Set up placeholders
    y = tf.placeholder(tf.float32, shape=[None, h, w, num_parts + num_connections])
    lr = tf.placeholder(tf.float32)
    loss_mask = tf.placeholder(tf.float32, shape=[batch_size])
    # Consistency-weight placeholder; created unconditionally because the
    # training loop below always feeds it.
    m = tf.placeholder(tf.float32, shape=[])

    # Loss
    if not use_mean_teacher:
        avr_loss = tf.losses.mean_squared_error(y, output_map,
                                                weights=tf.reshape(loss_mask,
                                                                   [batch_size, 1, 1, 1]))
    else:
        ema = tf.train.ExponentialMovingAverage(decay=.95)

        def ema_getter(getter, name, *args, **kwargs):
            var = getter(name, *args, **kwargs)
            ema_var = ema.average(var)
            return ema_var if ema_var is not None else var  # ema.average() is None before apply()

        tf.get_variable_scope().set_custom_getter(ema_getter)
        model_vars = tf.trainable_variables()
        output_student = output_map
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema.apply(model_vars))
        output_teacher, _ = unet(input_t, .9 if dataset == "RMIT" else .7, 3,
                                 num_parts + num_connections,
                                 is_training=is_training,
                                 features_root=64,
                                 alpha=alpha)
        output_teacher = tf.stop_gradient(output_teacher)
        avr_loss = batch_size / tf.reduce_sum(loss_mask) * \
                   tf.losses.mean_squared_error(y, output_student,
                                                weights=tf.reshape(loss_mask,
                                                                   [batch_size, 1, 1, 1]))
        avr_loss = avr_loss + m * .1 * tf.losses.mean_squared_error(output_teacher, output_student)

    if use_vat:
        avr_loss = batch_size / tf.reduce_sum(loss_mask) * avr_loss + \
                   virtual_adversarial_loss(input_t, y, is_training=is_training, alpha=alpha)

    # Adam solver
    with tf.variable_scope("Adam", reuse=tf.AUTO_REUSE):
        opt = tf.train.AdamOptimizer(lr).minimize(avr_loss)

    # Start session and initialize weights
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            allow_soft_placement=True,
                                            log_device_placement=True))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=10000)

    b_train = Batch(root, batch_size, dataset="ENDOVIS",
                    include_unlabelled=use_vat or use_mean_teacher,
                    pseudo_label=use_pseudo_labels)
    b_test = Batch(root, batch_size, dataset="ENDOVIS", include_unlabelled=False, testing=True, augment=False,
                   train_postprocessing=False)

    current_lr = 1e-3 
    print("Chosen lr:", current_lr)

    # if model_dir is not None:
    #     restore_op, restore_dict = tf.contrib.framework.assign_from_checkpoint(
    #         model_dir + "/model.ckpt",
    #         tf.contrib.slim.get_variables_to_restore(),
    #         ignore_missing_vars=True
    #     )
    #     sess.run(restore_op, feed_dict=restore_dict)
    #     print("Restored session")

    # save graph
    writer = tf.summary.FileWriter(logdir='logdir', graph=sess.graph)
    writer.flush()

    if use_vat:
        test_interval = 250
    else:
        test_interval = 200

    def sigmoid_schedule(global_step, warm_up_steps=20000):
        """Mean-teacher consistency-weight ramp-up: rises smoothly from
        exp(-5) ~ 0.0067 to 1.0 over the first warm_up_steps iterations,
        then stays at 1."""
        if global_step > warm_up_steps:
            return 1.

        return np.exp(-5. * (1. - (global_step / warm_up_steps)) ** 2)

    for i in range(max_it):

        imgs, targets, _, mask = b_train.get_batch()

        current_loss, net_out, _ = sess.run(
            [avr_loss, output_map, opt],
            feed_dict={input_t: imgs,
                       y: targets,
                       lr: current_lr,
                       is_training: True,
                       alpha: 1 / np.random.uniform(low=3, high=8),
                       loss_mask: mask,
                       m: sigmoid_schedule(i)
                       }
        )

        if i % 100 == 0:
            print("Current regression loss:", current_loss.sum())
            loc_pred = []
            loc_true = []
            for ch in range(num_parts):
                if b_train.batch_instrument_count[0] == 1:
                    _, _, _, m_loc1 = cv2.minMaxLoc(net_out[0, :, :, ch])
                    loc_pred.append(m_loc1)
                    _, _, _, m_loc2 = cv2.minMaxLoc(targets[0][:, :, ch])
                    loc_true.append(m_loc2)
                else:
                    pass

            print("For the first sample-> Predicted: {}    Ground Truth: {}\n".format(loc_pred, loc_true))

        # save model for evaluation
        if i % test_interval == 0 and i != 0:

            print("Testing at iteration", i, "...")
            dir2save = os.path.join("tmp" + str(i), "model.ckpt")
            save_path = saver.save(sess, dir2save)
            print("Saved model to", save_path)

    sess.close()
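For reference, the warm-up weight in sigmoid_schedule is exp(-5 * (1 - t)^2) with t = step / warm_up_steps, so with the default of 20000 steps it starts near zero and reaches exactly 1 at the boundary:

# Worked values of sigmoid_schedule(step, warm_up_steps=20000):
#   step       0 -> exp(-5.00) ~ 0.0067
#   step   10000 -> exp(-1.25) ~ 0.2865
#   step >= 20000 -> 1.0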