Example #1
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(
        0, 1000)
    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)
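    # sup_by_label is a list with one array of labeled images per class.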

    # Optionally add a few extra labeled samples to every class.
    if FLAGS.testadd:
        num_to_add = 10
        for i in range(10):
            items = np.where(train_labels == i)[0]
            inds = np.random.choice(len(items), num_to_add, replace=False)
            sup_by_label[i] = np.vstack(
                [sup_by_label[i], train_images[items[inds]]])

    add_random_samples = 0
    if add_random_samples > 0:
        rng = np.random.RandomState()
        indices = rng.choice(len(train_images), add_random_samples, False)

        for i in indices:
            l = train_labels[i]
            sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])
            print(l)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     dropout_keep_prob=0.8)

        # Set up inputs.
        if FLAGS.random_batches:
            sup_lbls = np.asarray(
                np.hstack([
                    np.ones(len(i)) * ind for ind, i in enumerate(sup_by_label)
                ]), np.int64)
            sup_images = np.vstack(sup_by_label)
            t_sup_images, t_sup_labels = semisup.create_input(
                sup_images, sup_lbls, FLAGS.sup_per_batch * NUM_LABELS)
        else:
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            if FLAGS.equal_cls_unsup:
                allimgs_bylabel = semisup.sample_by_label(
                    train_images, train_labels, 5000, NUM_LABELS, seed)
                t_unsup_images, _ = semisup.create_per_class_inputs(
                    allimgs_bylabel, FLAGS.sup_per_batch)
            else:
                t_unsup_images, _ = semisup.create_input(
                    train_images, train_labels, FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(t_sup_emb,
                                   t_unsup_emb,
                                   t_sup_labels,
                                   walker_weight=FLAGS.walker_weight,
                                   visit_weight=FLAGS.visit_weight,
                                   proximity_weight=FLAGS.proximity_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
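
The add_semisup_loss calls in these examples implement the walker and visit losses from "Learning by Association" (Haeusser et al.). A minimal self-contained sketch of the idea, assuming TF 1.x; names and details are illustrative rather than the library's exact internals:

import tensorflow as tf

def semisup_loss_sketch(sup_emb, unsup_emb, sup_labels,
                        walker_weight=1.0, visit_weight=1.0):
    """Hedged sketch of an association-based semi-supervised loss."""
    # Pairwise similarities between labeled and unlabeled embeddings.
    match = tf.matmul(sup_emb, unsup_emb, transpose_b=True)
    p_ab = tf.nn.softmax(match)                 # labeled -> unlabeled
    p_ba = tf.nn.softmax(tf.transpose(match))   # unlabeled -> labeled
    p_aba = tf.matmul(p_ab, p_ba)               # round-trip probabilities

    # Walker loss: a round trip should end at a sample of the same class.
    labels = tf.reshape(sup_labels, [-1, 1])
    equality = tf.cast(tf.equal(labels, tf.transpose(labels)), tf.float32)
    target = equality / tf.reduce_sum(equality, 1, keepdims=True)
    walker_loss = tf.losses.softmax_cross_entropy(
        target, tf.log(p_aba + 1e-8), weights=walker_weight)

    # Visit loss: every unlabeled sample should be visited with roughly
    # equal probability.
    visit_prob = tf.reduce_mean(p_ab, 0, keepdims=True)
    n = tf.cast(tf.shape(visit_prob)[1], tf.float32)
    uniform = tf.fill(tf.shape(visit_prob), 1.0 / n)
    visit_loss = tf.losses.softmax_cross_entropy(
        uniform, tf.log(visit_prob + 1e-8), weights=visit_weight)

    return walker_loss + visit_loss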
Example #2
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS, IMAGE_SHAPE)

        # Set up inputs.
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
            sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        model.add_semisup_loss(t_sup_emb,
                               t_unsup_emb,
                               t_sup_labels,
                               visit_weight=FLAGS.visit_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
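
Every example starts from sample_by_label, which draws a fixed number of labeled images per class and returns them as a list indexed by label. A minimal NumPy sketch consistent with the call sites above (hypothetical, not the library's exact code):

import numpy as np

def sample_by_label_sketch(images, labels, n_per_class, num_labels, seed=None):
    """Return a list with n_per_class images for each of num_labels classes."""
    rng = np.random.RandomState(seed)
    by_label = []
    for cls in range(num_labels):
        idx = np.where(labels == cls)[0]
        picked = rng.choice(len(idx), n_per_class, replace=False)
        by_label.append(images[idx[picked]])
    return by_label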
Example #3
def main(_):
    # Get dataset-related toolbox.
    dataset_tools = import_module('semisup.tools.' + FLAGS.dataset)
    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE
    test_images, test_labels = dataset_tools.get_data('test')
    print(test_images)

    graph = tf.Graph()
    with graph.as_default():

        # Set up input pipeline.
        #image, label = tf.train.slice_input_producer([test_images, test_labels])
        #images, labels = tf.train.batch(
        #    [image, label], batch_size=FLAGS.eval_batch_size)
        images, labels = semisup.create_input(test_images, test_labels,
                                              FLAGS.eval_batch_size)

        images = tf.cast(images, tf.float32)
        labels = tf.cast(labels, tf.int64)

        # Reshape if necessary.
        if FLAGS.new_size > 0:
            new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
        else:
            new_shape = None

        if FLAGS.augmentation:
            # TODO(haeusser) generalize augmentation
            def _random_invert(inputs1, _):
                inputs = tf.cast(inputs1, tf.float32)
                inputs = tf.image.adjust_brightness(
                    inputs, tf.random_uniform((1, 1), 0.0, 0.5))
                inputs = tf.image.random_contrast(inputs, 0.3, 1)
                # inputs = tf.image.per_image_standardization(inputs)
                inputs = tf.image.random_hue(inputs, 0.05)
                inputs = tf.image.random_saturation(inputs, 0.5, 1.1)

                def f1(): return tf.abs(inputs)  # annotations

                def f2(): return tf.abs(inputs1)

                return tf.cond(tf.less(tf.random_uniform([], 0.0, 1), 0.5), f1, f2)

            augmentation_function = _random_invert
        else:
            augmentation_function = None

        # Create function that defines the network.
        model_function = partial(
            architecture,
            is_training=False,
            new_shape=new_shape,
            img_shape=image_shape,
            augmentation_function=augmentation_function,
            image_summary=False,
            emb_size=FLAGS.emb_size)


        # Set up semisup model.
        model = semisup.SemisupModel(
            model_function,
            num_labels,
            image_shape,
            test_in=images)

        # Add moving average variables.
        for var in tf.get_collection('moving_vars'):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        for var in slim.get_model_variables():
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

        # Get prediction tensor from semisup model.
        predictions = tf.argmax(model.test_logit, 1)
        cmatrix = tf.confusion_matrix(labels, predictions, num_labels)
        # Accuracy metric for summaries.
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        })
        for name, value in names_to_values.items():
            tf.summary.scalar(name, value)
        confusion_image = tf.reshape(tf.cast(cmatrix, tf.float32),
                                     [1, num_labels, num_labels, 1])
        tf.summary.tensor_summary('cmatrix', cmatrix)
        tf.summary.image('confusion matrix', confusion_image)
        tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, figsize=(7, 5),
                                          cbar=False, annot=True,
                                          yticklabels=dataset_tools.CLASSES,
                                          cmap='jet')
        tf.summary.image('heat_maps',
                         tf.reshape(tf_heatmap(cmatrix), [1, 500, 700, 4]))
        # Run the actual evaluation loop.
        num_batches = math.ceil(len(test_labels) / float(FLAGS.eval_batch_size))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.logdir + '/train',
            logdir=FLAGS.logdir + '/eval',
            num_evals=num_batches,
            eval_op=tf.Print(list(names_to_updates.values()),
                             [confusion_image],
                             message="cmatrix:",
                             summarize=500),
            eval_interval_secs=FLAGS.eval_interval_secs,
            session_config=config,
            timeout=FLAGS.timeout,
            #hooks=[tf_debug.LocalCLIDebugHook(ui_type="readline")]
        )
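
The commented-out slice_input_producer/batch lines above show the queue-based pipeline that semisup.create_input wraps. A plausible minimal version under the same TF 1.x queue API (the unlabeled branch is an assumption based on the create_input(images, None, ...) call in later examples):

import tensorflow as tf

def create_input_sketch(images, labels, batch_size):
    """Slice in-memory arrays into a queue and batch them (TF 1.x queues)."""
    if labels is not None:
        image, label = tf.train.slice_input_producer([images, labels])
        return tf.train.batch([image, label], batch_size=batch_size)
    image = tf.train.slice_input_producer([images])[0]
    return tf.train.batch([image], batch_size=batch_size)[0]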
Example #4
def main(argv):
    del argv

    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, _ = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # need to add additional labels for virtual embeddings
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            visit_weight_envelope_steps = (FLAGS.walker_weight_envelope_steps
                                           if FLAGS.visit_weight_envelope_steps
                                           == -1 else
                                           FLAGS.visit_weight_envelope_steps)
            visit_weight_envelope_delay = (FLAGS.walker_weight_envelope_delay
                                           if FLAGS.visit_weight_envelope_delay
                                           == -1 else
                                           FLAGS.visit_weight_envelope_delay)
            visit_weight = apply_envelope(
                type=FLAGS.visit_weight_envelope,
                step=model.step,
                final_weight=FLAGS.visit_weight,
                growing_steps=visit_weight_envelope_steps,
                delay=visit_weight_envelope_delay)
            walker_weight = apply_envelope(
                type=FLAGS.walker_weight_envelope,
                step=model.step,
                final_weight=FLAGS.walker_weight,
                growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                delay=FLAGS.walker_weight_envelope_delay)
            tf.summary.scalar('Weights_Visit', visit_weight)
            tf.summary.scalar('Weights_Walker', walker_weight)

            if FLAGS.unsup_samples != 0:
                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(FLAGS.learning_rate,
                                           model.step,
                                           FLAGS.decay_steps,
                                           FLAGS.decay_factor,
                                           staircase=True),
                FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
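
apply_envelope ramps a loss weight from 0 up to final_weight over growing_steps steps, after an optional delay. A NumPy sketch of the schedule shapes for scalar steps (the examples that pass step=model.step use an analogous tensor-valued variant; this is illustrative, not the exact helper):

import numpy as np

def apply_envelope_sketch(type, step, final_weight, growing_steps, delay):
    """Ramp a weight from 0 to final_weight over growing_steps after delay."""
    t = float(step - delay)
    if type is None:
        value = final_weight
    elif type in ('lin', 'linear'):
        value = final_weight * t / growing_steps
    elif type in ('log', 'logarithmic'):
        # Fast early growth that saturates at final_weight.
        value = final_weight * (1.0 - 10.0 ** (-t / growing_steps))
    else:
        raise NameError('Unknown envelope type: %s' % type)
    return float(np.clip(value, 0.0, final_weight))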
Example #5
def main(_):
  FLAGS.emb_size = 128
  optimizer = 'adam'

  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)

  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():
    model_func = semisup.architectures.mnist_model
    if FLAGS.dropout_keep_prob < 1:
        model_func = semisup.architectures.mnist_model_dropout

    model = semisup.SemisupModel(model_func, NUM_LABELS,
                                 IMAGE_SHAPE, optimizer=optimizer, emb_size=FLAGS.emb_size,
                                 dropout_keep_prob=FLAGS.dropout_keep_prob)

    # Set up inputs.
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
      sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    t_unsup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)

    proximity_weight = tf.placeholder("float", shape=[])
    visit_weight = tf.placeholder("float", shape=[])
    walker_weight = tf.placeholder("float", shape=[])
    t_logit_weight = tf.placeholder("float", shape=[])
    t_l1_weight = tf.placeholder("float", shape=[])
    t_learning_rate = tf.placeholder("float", shape=[])

    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    model.add_semisup_loss(
      t_sup_emb, t_unsup_emb, t_sup_labels,
      walker_weight=walker_weight, visit_weight=visit_weight, proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels, weight=t_logit_weight)

    model.add_emb_regularization(t_sup_emb, weight=t_l1_weight)
    model.add_emb_regularization(t_unsup_emb, weight=t_l1_weight)

    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  sess = tf.InteractiveSession(graph=graph)

  unsup_images_iterator = semisup.create_input(train_images, train_labels,
                                               FLAGS.unsup_batch_size)
  tf.global_variables_initializer().run()

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  use_new_visit_loss = False
  for step in range(FLAGS.max_steps):
    unsup_images, _ = sess.run(unsup_images_iterator)

    if use_new_visit_loss:
      _, summaries, train_loss = sess.run(
          [train_op, summary_op, model.train_loss], {
              t_unsup_images: unsup_images,
              walker_weight: FLAGS.walker_weight,
              proximity_weight:
                  0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0)
                  - apply_envelope("lin", step, FLAGS.visit_weight, 2000,
                                   FLAGS.warmup_steps),
              t_logit_weight: FLAGS.logit_weight,
              t_l1_weight: FLAGS.l1_weight,
              visit_weight:
                  apply_envelope("lin", step, FLAGS.visit_weight, 2000,
                                 FLAGS.warmup_steps),
              t_learning_rate:
                  5e-5 + apply_envelope("log", step, FLAGS.learning_rate,
                                        FLAGS.warmup_steps, 0)
          })
    else:
      _, summaries, train_loss = sess.run(
          [train_op, summary_op, model.train_loss], {
              t_unsup_images: unsup_images,
              walker_weight: FLAGS.walker_weight,
              visit_weight:
                  0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0),
              proximity_weight: 0,
              t_logit_weight: FLAGS.logit_weight,
              t_l1_weight: FLAGS.l1_weight,
              t_learning_rate:
                  5e-5 + apply_envelope("log", step, FLAGS.learning_rate,
                                        FLAGS.warmup_steps, 0)
          })

    if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
      print('Step: %d' % step)
      test_pred = model.classify(test_images).argmax(-1)
      conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
      test_err = (test_labels != test_pred).mean() * 100
      print(conf_mtx)
      print('Test error: %.2f %%' % test_err)
      print('Train loss: %.2f ' % train_loss)
      print()

      test_summary = tf.Summary(
        value=[tf.Summary.Value(
          tag='Test Err', simple_value=test_err)])

      summary_writer.add_summary(summaries, step)
      summary_writer.add_summary(test_summary, step)
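
Example #5 also calls add_emb_regularization with a t_l1_weight placeholder. A one-line sketch of what that plausibly adds, assuming an L1 penalty on embedding magnitudes (illustrative only):

import tensorflow as tf

def add_emb_regularization_sketch(embs, weight):
    """Add an L1 penalty on embedding magnitudes to the collected losses."""
    tf.losses.add_loss(weight * tf.reduce_mean(tf.reduce_sum(tf.abs(embs), 1)))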
Example #6
def main(_):
    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, _ = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) revert this to the general case
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert

                # if hasattr(dataset_tools, 'augmentation_params'):
                #    augmentation_function = partial(
                #        apply_augmentation, params=dataset_tools.augmentation_params)
                # else:
                #    augmentation_function = apply_affine_augmentation
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # need to add additional labels for virtual embeddings
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.mmd:
                sys.path.insert(0, '/usr/wiss/haeusser/libs/opt-mmd/gan')
                from mmd import mix_rbf_mmd2

                bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]  # original

                t_sup_flat = tf.reshape(t_sup_emb,
                                        [FLAGS.sup_per_batch * num_labels, -1])
                t_unsup_flat = tf.reshape(t_unsup_emb,
                                          [FLAGS.unsup_batch_size, -1])
                mmd_loss = mix_rbf_mmd2(t_sup_flat,
                                        t_unsup_flat,
                                        sigmas=bandwidths)
                tf.losses.add_loss(mmd_loss)
                tf.summary.scalar('MMD_loss', mmd_loss)
            else:
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1 else
                    FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1 else
                    FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)

                if FLAGS.unsup_samples != 0:
                    model.add_semisup_loss(t_sup_emb,
                                           t_unsup_emb,
                                           t_sup_labels,
                                           visit_weight=visit_weight,
                                           walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate schedule if necessary.
            if FLAGS.custom_lr_vals is not None and FLAGS.custom_lr_steps is not None:
                boundaries = [
                    tf.convert_to_tensor(x, tf.int64)
                    for x in FLAGS.custom_lr_steps
                ]

                t_learning_rate = piecewise_constant(model.step, boundaries,
                                                     FLAGS.custom_lr_vals)
            else:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            ''' BEGIN EVIL STUFF !!!
            checkpoint_path = '/usr/wiss/haeusser/experiments/inception_v4/model.ckpt'
            mapping = dict()
            for x in slim.get_model_variables():
                name = x.name[:-2]
                ok = True
                for banned in ['Logits', 'fc1', 'fully_connected', 'ExponentialMovingAverage']:
                    if banned in name:
                        ok = False
                if ok:
                    mapping[name[4:]] = x

            init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                checkpoint_path, mapping)

            # Create an initial assignment function.
            def InitAssignFn(sess):
                sess.run(init_assign_op, init_feed_dict)
                print("#################################### Checkpoint loaded.")
             '''  # END EVIL STUFF !!!

            slim.learning.train(
                train_op,
                # init_fn=InitAssignFn,  # EVIL, too
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
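
The custom schedule branch above relies on a piecewise_constant helper; core TF 1.x ships an equivalent. For example, with illustrative boundaries and values:

import tensorflow as tf

# 1e-3 until step 10000, then 1e-4 until step 20000, then 1e-5.
global_step = tf.train.get_or_create_global_step()
t_learning_rate = tf.train.piecewise_constant(
    tf.cast(global_step, tf.int32),
    boundaries=[10000, 20000], values=[1e-3, 1e-4, 1e-5])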
Example #7
def main(_):
    if FLAGS.logdir is not None:
        FLAGS.logdir = FLAGS.logdir + '/t_mnist_eval'
        try:
            shutil.rmtree(FLAGS.logdir)
        except OSError as e:
            print ("Error: %s - %s." % (e.filename, e.strerror))

    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    if FLAGS.sup_seed >= 0:
        seed = FLAGS.sup_seed
    elif FLAGS.sup_seed == -2:
        seed = FLAGS.sup_per_class
    else:
        seed = np.random.randint(0, 1000)

    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS, seed)


    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                     IMAGE_SHAPE, dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up inputs.
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                    sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                         FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(
                    t_sup_emb, t_unsup_emb, t_sup_labels,
                    walker_weight=FLAGS.walker_weight, visit_weight=FLAGS.visit_weight)

            #model.add_emb_regularization(t_unsup_emb, weight=FLAGS.l1_weight)

        model.add_logit_loss(t_sup_logit, t_sup_labels, weight=FLAGS.logit_weight)

        #model.add_emb_regularization(t_sup_emb, weight=FLAGS.l1_weight)

        t_learning_rate = tf.placeholder("float", shape=[])

        train_op = model.create_train_op(t_learning_rate)

        summary_op = tf.summary.merge_all()
        if FLAGS.logdir is not None:
            summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)
            saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        learning_rate_ = FLAGS.learning_rate

        for step in range(FLAGS.max_steps):
            lr = learning_rate_
            if step < FLAGS.warmup_steps:
                lr = 1e-6 + semisup.apply_envelope(
                    "log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)

            _, summaries = sess.run([train_op, summary_op], {
              t_learning_rate: lr
            })

            sys.stderr.write("\rstep: %d" % step)
            sys.stderr.flush()
    
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('\nStep: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)

    
                if FLAGS.logdir is not None:
                    sum_values = {'Test error': test_err}
                    summary_writer.add_summary(summaries, step)
                    for key, value in sum_values.items():
                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag=key, simple_value=value)
                        ])
                        summary_writer.add_summary(summary, step)

            if step % FLAGS.decay_steps == 0 and step > 0:
                learning_rate_ = learning_rate_ * FLAGS.decay_factor

        coord.request_stop()
        coord.join(threads)

    print('FINAL RESULTS:')
    print('Test error: %.2f %%' % (test_err))
    print('final_score', 1 - test_err/100)

    print('@@test_error:%.4f' % (test_err/100))
    print('@@train_loss:%.4f' % 0)
    print('@@reg_loss:%.4f' % 0)
    print('@@estimated_error:%.4f' % 0)
    print('@@centroid_norm:%.4f' % 0)
    print('@@emb_norm:%.4f' % 0)
    print('@@k_score:%.4f' % 0)
    print('@@svm_score:%.4f' % 0)
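
semisup.confusion_matrix, used by the evaluation code in these examples, can be sketched in NumPy; rows are true labels, columns are predictions (a hypothetical minimal version):

import numpy as np

def confusion_matrix_sketch(labels, predictions, num_labels):
    """Count (true label, prediction) pairs into a square grid."""
    mtx = np.zeros((num_labels, num_labels), dtype=np.int64)
    for t, p in zip(labels, predictions):
        mtx[t, p] += 1
    return mtx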
Example #8
def init_graph(sup_by_label,
               train_images,
               train_labels,
               test_images,
               test_labels,
               logdir,
               unsup_batch_size=None):
    if unsup_batch_size is None:
        unsup_batch_size = FLAGS.unsup_batch_size

    graph = tf.Graph()
    with graph.as_default():
        model_func = semisup.architectures.mnist_model
        if FLAGS.dropout_keep_prob < 1:
            model_func = semisup.architectures.mnist_model_dropout

        model = semisup.SemisupModel(model_func,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up inputs.
        t_sup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)
        t_sup_labels = tf.placeholder(tf.int32, shape=[None])

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        t_unsup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)

        proximity_weight = tf.placeholder("float", shape=[])
        visit_weight = tf.placeholder("float", shape=[])
        walker_weight = tf.placeholder("float", shape=[])
        t_logit_weight = tf.placeholder("float", shape=[])
        t_l1_weight = tf.placeholder("float", shape=[])
        t_learning_rate = tf.placeholder("float", shape=[])

        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        model.add_semisup_loss(t_sup_emb,
                               t_unsup_emb,
                               t_sup_labels,
                               walker_weight=walker_weight,
                               visit_weight=visit_weight,
                               proximity_weight=proximity_weight,
                               normalize_along_classes=True)
        model.add_logit_loss(t_sup_logit, t_sup_labels, weight=t_logit_weight)

        model.add_emb_regularization(t_sup_emb, weight=t_l1_weight)
        model.add_emb_regularization(t_unsup_emb, weight=t_l1_weight)

        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(logdir, graph)

        saver = tf.train.Saver()
        unsup_images_iterator = semisup.create_input(train_images,
                                                     train_labels,
                                                     unsup_batch_size)

        sess = tf.Session(graph=graph)
        sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    def init_iterators():
        input_graph = tf.Graph()
        with input_graph.as_default():
            sup_images_iterator, sup_labels_iterator = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            input_sess = tf.Session(graph=input_graph)
            input_sess.run(tf.global_variables_initializer())

        input_coord = tf.train.Coordinator()
        input_threads = tf.train.start_queue_runners(sess=input_sess,
                                                     coord=input_coord)

        return sup_images_iterator, sup_labels_iterator, input_sess, input_threads, input_coord

    def train_warmup():
        """ if there are few samples, warmup at first"""
        (sup_images_iterator, sup_labels_iterator, input_sess, input_threads,
         input_coord) = init_iterators()

        model.reset_optimizer(sess)

        for step in range(FLAGS.max_steps):
            unsup_images, _ = sess.run(unsup_images_iterator)
            si, sl = input_sess.run([sup_images_iterator, sup_labels_iterator])

            _, summaries, train_loss = sess.run(
                [train_op, summary_op, model.train_loss], {
                    t_unsup_images: unsup_images,
                    walker_weight: FLAGS.walker_weight,
                    proximity_weight: 0,
                    visit_weight: 0.1 + apply_envelope(
                        "lin", step, 0.4, FLAGS.warmup_steps, 0),
                    t_logit_weight: 0.5,  # FLAGS.logit_weight,
                    t_l1_weight: FLAGS.l1_weight,
                    t_sup_images: si,
                    t_sup_labels: sl,
                    t_learning_rate: 5e-5 + apply_envelope(
                        "log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
                })

            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print('Train loss: %.2f ' % train_loss)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()

    def train_finetune(lr=0.001, steps=FLAGS.max_steps):
        (sup_images_iterator, sup_labels_iterator, input_sess, input_threads,
         input_coord) = init_iterators()

        model.reset_optimizer(sess)
        for step in range(int(steps)):
            unsup_images, _ = sess.run(unsup_images_iterator)
            si, sl = input_sess.run([sup_images_iterator, sup_labels_iterator])
            _, summaries, train_loss = sess.run(
                [train_op, summary_op, model.train_loss], {
                    t_unsup_images: unsup_images,
                    walker_weight: 1,
                    proximity_weight: 0,
                    visit_weight: 1,
                    t_l1_weight: 0,
                    t_logit_weight: 1,
                    t_learning_rate: lr,
                    t_sup_images: si,
                    t_sup_labels: sl
                })
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print('Train loss: %.2f ' % train_loss)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()

    def choose_sample(method="propose_samples"):
        (sup_images_iterator, sup_labels_iterator, input_sess, input_threads,
         input_coord) = init_iterators()

        sup_lbls = np.hstack(
            [np.ones(len(i)) * ind for ind, i in enumerate(sup_by_label)])
        sup_images = np.vstack(sup_by_label)

        print("total number of supervised samples", sup_lbls.shape)

        n_samples = FLAGS.num_new_samples_per_run
        propose_func = getattr(model, method)
        inds_lba = propose_func(sup_images,
                                sup_lbls,
                                train_images,
                                train_labels,
                                sess,
                                n_samples=n_samples)
        print('sampled new images with labels', train_labels[inds_lba])

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()

        return inds_lba

    def add_sample(index):
        # add sample
        label = train_labels[index]
        img = train_images[index]
        sup_by_label[label] = np.vstack([sup_by_label[label], [img]])

    return train_warmup, train_finetune, choose_sample, add_sample
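
init_graph returns closures for an active-learning loop. A hypothetical driver, with the logdir path and number of acquisition rounds as assumptions:

# Hypothetical usage of the closures returned by init_graph above.
train_warmup, train_finetune, choose_sample, add_sample = init_graph(
    sup_by_label, train_images, train_labels, test_images, test_labels,
    logdir='/tmp/semisup_active')

train_warmup()                      # initial training on the small seed set
for _ in range(5):                  # acquisition rounds (assumed)
    for index in choose_sample():   # model proposes samples to label next
        add_sample(index)           # "label" them by copying ground truth
    train_finetune(lr=0.001, steps=1000)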
Example #9
def Eval(FLAGS):
    # Get dataset-related toolbox.
    dataset_tools = import_module('semisup.tools.' + FLAGS.dataset)
    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    test_images, test_labels = dataset_tools.get_data('test')

    graph = tf.Graph()
    with graph.as_default():

        # Set up input pipeline.
        #image, label = tf.train.slice_input_producer([test_images, test_labels])
        #images, labels = tf.train.batch(
        #    [image, label], batch_size=FLAGS.eval_batch_size)
        images, labels = semisup.create_input(test_images, test_labels,
                                              FLAGS.eval_batch_size)

        images = tf.cast(images, tf.float32)
        labels = tf.cast(labels, tf.int64)

        # Reshape if necessary.
        if FLAGS.new_size > 0:
            new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
        else:
            new_shape = None

        # Create function that defines the network.
        model_function = partial(architecture,
                                 is_training=False,
                                 new_shape=new_shape,
                                 img_shape=image_shape,
                                 augmentation_function=None,
                                 image_summary=False,
                                 emb_size=FLAGS.emb_size)

        # Set up semisup model.
        model = semisup.SemisupModel(model_function,
                                     num_labels,
                                     image_shape,
                                     test_in=images)

        # Add moving average variables.
        for var in tf.get_collection('moving_vars'):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        for var in slim.get_model_variables():
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

        # Get prediction tensor from semisup model.
        predictions = tf.argmax(model.test_logit, 1)

        # Accuracy metric for summaries.
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
        })
        for name, value in names_to_values.items():
            tf.summary.scalar(name, value)

        # Run the actual evaluation loop.
        num_batches = math.ceil(
            len(test_labels) / float(FLAGS.eval_batch_size))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        #checkpoint_path = tf.train.latest_checkpoint(FLAGS.logdir + '/train')
        metric_values = slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.logdir + '/' + str(FLAGS.unsup_samples) +
            '/train',
            logdir=FLAGS.logdir + '/' + str(FLAGS.unsup_samples) + '/eval',
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            eval_interval_secs=1,
            session_config=config,
            timeout=2,
            final_op=list(names_to_values.values()))
        #print (metric_values)
    return metric_values
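
Example #10 below calls cyclic_learning_rate(...) in 'triangular' and 'exp_range' modes. A minimal sketch matching those call sites, after Smith's cyclical-learning-rate scheme (hypothetical, not the exact helper used):

import tensorflow as tf

def cyclic_learning_rate_sketch(step, base_lr, max_lr, step_size,
                                mode='triangular', gamma=0.9994):
    """Cycle the LR between base_lr and max_lr with period 2 * step_size."""
    t = tf.cast(step, tf.float32)
    cycle = tf.floor(1.0 + t / (2.0 * step_size))
    x = tf.abs(t / step_size - 2.0 * cycle + 1.0)
    scale = tf.maximum(0.0, 1.0 - x)
    if mode == 'exp_range':
        scale *= tf.pow(gamma, t)  # exponentially shrink the amplitude
    return base_lr + (max_lr - base_lr) * scale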
Example #10
def train_test(FLAGS):

    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, train_images_label = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, train_images_label = dataset_tools.get_data(
            'unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))
        # TODO: make sample selections per class (done).
        #unsup_by_label = semisup.sample_by_label(train_images_unlabeled, train_images_label,
        #                                       FLAGS.unsup_samples/num_labels+num_labels, num_labels,
        #                                       seed)

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            #print(t_sup_images.shape)
            #with tf.Session() as sess: print (t_sup_images.eval().shape)
            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs1, _):
                    inputs = tf.cast(inputs1, tf.float32)
                    inputs = tf.image.adjust_brightness(
                        inputs, tf.random_uniform((1, 1), 0.0, 0.5))
                    inputs = tf.image.random_contrast(inputs, 0.3, 1)
                    # inputs = tf.image.per_image_standardization(inputs)
                    inputs = tf.image.random_hue(inputs, 0.05)
                    inputs = tf.image.random_saturation(inputs, 0.5, 1.1)

                    def f1():
                        return tf.abs(inputs)  #annotations

                    def f2():
                        return tf.abs(inputs1)

                    return tf.cond(tf.less(tf.random_uniform([], 0.0, 1), 0.5),
                                   f1, f2)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.unsup_samples != 0:
                t_unsup_emb = model.image_to_embedding(t_unsup_images)
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1 else
                    FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1 else
                    FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)

                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate
            if FLAGS.learning_rate_type is None:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)
            elif FLAGS.learning_rate_type == 'exp2':
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.maximum_learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='exp_range',
                                         gamma=0.9999),
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994))

            else:
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994),
                    FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            final_loss = slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
                #session_wrapper=tf_debug.LocalCLIDebugWrapperSession
            )

            print(final_loss)
    return final_loss
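Both apply_envelope and cyclic_learning_rate above are helpers from the surrounding training script, not TensorFlow built-ins. As a rough sketch of what they plausibly compute (the exact ramp shapes here are assumptions, not the script's code): apply_envelope grows a loss weight from 0 to final_weight over growing_steps steps once delay steps have passed, and cyclic_learning_rate follows Smith's cyclical schedules, where 'triangular' oscillates linearly between the bounds and 'exp_range' additionally shrinks the amplitude by gamma**step.

import tensorflow as tf

def apply_envelope(type, step, final_weight, growing_steps, delay):
    """Sketch: grow a weight from 0 to final_weight after `delay` steps."""
    t = tf.maximum(0.0, tf.cast(step - delay, tf.float32)) / growing_steps
    if type is None:
        return tf.constant(final_weight, tf.float32)  # constant weight
    elif type == 'lin':
        envelope = tf.minimum(1.0, t)  # linear ramp, capped at 1
    elif type == 'log':
        # Fast early growth, flattening out towards t = 1.
        envelope = tf.minimum(1.0, tf.log(1.0 + 9.0 * t) / tf.log(10.0))
    else:
        raise ValueError('Unknown envelope type: %s' % type)
    return final_weight * envelope

def cyclic_learning_rate(step, min_lr, max_lr, cycle_steps,
                         mode='triangular', gamma=0.9994):
    """Sketch of a cyclical learning rate (Smith, 2017)."""
    step_f = tf.cast(step, tf.float32)
    cycle = tf.floor(1.0 + step_f / (2.0 * cycle_steps))
    x = tf.abs(step_f / cycle_steps - 2.0 * cycle + 1.0)
    amplitude = tf.maximum(0.0, 1.0 - x)  # triangle wave in [0, 1]
    if mode == 'exp_range':
        amplitude *= tf.pow(gamma, step_f)  # exponentially decaying peaks
    return min_lr + (max_lr - min_lr) * amplitude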
Exemple #11
0
def main(_):
  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  import numpy as np

  # Disabled experiment: build the labeled subset from hand-picked indices
  # instead of random sampling. Note the second `indices` assignment
  # overrides the first.
  if False:
    indices = [374, 2507, 9755, 12953, 16507, 16873, 23474, 23909, 30280,
               35070, 49603, 50106, 51171, 51726, 51805, 55205, 57251, 57296,
               57779, 59154] + \
              [16644, 45576, 52886, 42140, 29201, 7767, 24, 134, 8464, 15022,
               15715, 15602, 11030, 3898, 10195, 1454, 3290, 5293, 5806, 274]
    indices = [374, 2507, 9755, 12953, 16507, 16873, 23474, 23909, 30280,
               35070, 49603, 50106, 51171, 51726, 51805, 55205, 57251, 57296,
               57779, 59154] + \
              [9924, 34058, 53476, 15715, 6428, 33598, 33464, 41753, 21250,
               26389, 12950, 12464, 3795, 6761, 5638, 3952, 8300, 5632, 1475,
               1875]
    sup_by_label = [[] for _ in range(NUM_LABELS)]
    for ind in indices:
      label = train_labels[ind]
      sup_by_label[label] = sup_by_label[label] + [train_images[ind]]

    for i in range(NUM_LABELS):
      sup_by_label[i] = np.asarray(sup_by_label[i])

    sup_by_label = np.asarray(sup_by_label)


  # Add a few randomly drawn labeled samples on top of the balanced subset.
  add_random_samples = 20
  if add_random_samples > 0:
    rng = np.random.RandomState()
    indices = rng.choice(len(train_images), add_random_samples, False)

    for i in indices:
      l = train_labels[i]
      sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])


  graph = tf.Graph()
  with graph.as_default():
    model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                 IMAGE_SHAPE)

    # Set up inputs.
    t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                             FLAGS.unsup_batch_size)
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels, visit_weight=FLAGS.visit_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        model.step,
        FLAGS.decay_steps,
        FLAGS.decay_factor,
        staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      _, summaries = sess.run([train_op, summary_op])
      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
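model.add_semisup_loss is where the semi-supervised signal comes from: it implements the walker and visit losses of "Learning by Association" (Haeusser et al., 2017). A minimal sketch of the mechanism, assuming sup_emb/unsup_emb are the embedding batches and sup_labels the integer labels (a paraphrase of the idea, not the library's exact code):

import tensorflow as tf

def association_losses(sup_emb, unsup_emb, sup_labels,
                       walker_weight=1.0, visit_weight=1.0):
    """Sketch of the walker/visit losses behind add_semisup_loss."""
    # Similarities between labeled (A) and unlabeled (B) embeddings.
    match_ab = tf.matmul(sup_emb, unsup_emb, transpose_b=True)
    p_ab = tf.nn.softmax(match_ab)                # A -> B probabilities
    p_ba = tf.nn.softmax(tf.transpose(match_ab))  # B -> A probabilities
    p_aba = tf.matmul(p_ab, p_ba)                 # round trip A -> B -> A

    # Walker loss: round trips should end on a sample of the same class.
    equal = tf.cast(tf.equal(tf.reshape(sup_labels, [-1, 1]),
                             tf.reshape(sup_labels, [1, -1])), tf.float32)
    p_target = equal / tf.reduce_sum(equal, 1, keep_dims=True)
    # softmax(log p) recovers p, so this is cross-entropy against p_aba.
    walker_loss = walker_weight * tf.losses.softmax_cross_entropy(
        p_target, tf.log(1e-8 + p_aba))

    # Visit loss: unlabeled samples should all be visited equally often.
    visit_prob = tf.reduce_mean(p_ab, 0, keep_dims=True)
    n_b = tf.shape(p_ab)[1]
    uniform = tf.fill([1, n_b], 1.0 / tf.cast(n_b, tf.float32))
    visit_loss = visit_weight * tf.losses.softmax_cross_entropy(
        uniform, tf.log(1e-8 + visit_prob))

    return walker_loss + visit_loss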
Exemple #12
0
def main(argv):
    del argv
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    np.random.seed(seed)

    # Load data.
    if FLAGS.dataset.startswith('gs'):
        if FLAGS.target_dataset is not None:
            raise NotImplementedError(
                'target dataset is not supported for GS filesystem')
        train_images, train_labels = gs_data.get_data('train', FLAGS.dataset)
        train_images_unlabeled, _ = gs_data.get_data('unlabeled',
                                                     FLAGS.dataset)
        test_images, test_labels = gs_data.get_data('test', FLAGS.dataset)
    else:
        dataset_tools = import_module('tools.' + FLAGS.dataset)
        train_images, train_labels = dataset_tools.get_data('train')
        if FLAGS.target_dataset is not None:
            target_dataset_tools = import_module('tools.' +
                                                 FLAGS.target_dataset)
            train_images_unlabeled, _ = target_dataset_tools.get_data(
                FLAGS.target_dataset_split)
        else:
            train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

        test_images, test_labels = dataset_tools.get_data('test')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    tokenizer = None
    encoder = None
    nb_words = None
    seq_len = None
    if architecture == semisup.architectures.cnn_sentiment_model:
        if FLAGS.dataset.startswith('gs'):
            dataset_tools = gs_data
            with BytesIO(
                    file_io.read_file_to_string(
                        '{}/tokenizer.pkl'.format(FLAGS.dataset),
                        binary_mode=True)) as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)
        else:
            with open(FLAGS.dictionaries, 'rb') as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)

        dataset_tools.IMAGE_SHAPE = [seq_len]
        dataset_tools.NUM_LABELS = len(encoder.classes_)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            tf.set_random_seed(seed)
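            # The graph seed, NumPy sampling, and both input pipelines are
            # tied to the same seed for reproducible runs.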

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled,
                                                  None,
                                                  FLAGS.unsup_batch_size,
                                                  seed=seed,
                                                  shuffle=FLAGS.shuffle_input)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label,
                FLAGS.sup_per_batch,
                seed=seed,
                shuffle=FLAGS.shuffle_input)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
                    """Invert each image in the batch with probability 0.5."""
                    # Per-example coin flip, broadcast over (H, W, C).
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.reshape(randu, [-1, 1, 1, 1])
                    inputs = tf.cast(inputs, tf.float32)
                    # x -> |x - 255| inverts selected images, leaves the rest.
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            if architecture == semisup.architectures.cnn_sentiment_model:

                embedding_weights = None
                dim = None
                if FLAGS.w2v is not None:
                    d = dict()
                    for path in FLAGS.w2v.split(','):
                        tf.logging.info('Loading word2vec: {}'.format(path))
                        if path.startswith('gs'):
                            s = file_io.read_file_to_string(path)
                            try:
                                # Python 2: decode bytes, replacing bad chars.
                                s = unicode(s, errors='replace')
                            except NameError:
                                # Python 3: str is already unicode.
                                pass

                            with StringIO(s) as fin:
                                # w2v = KeyedVectors.load_word2vec_format(fin, unicode_errors='ignore')
                                d = utils.load_word2vec(fin, d)
                        else:
                            # w2v = KeyedVectors.load_word2vec_format(path, unicode_errors='ignore')
                            with open(path, 'r') as fin:
                                d = utils.load_word2vec(fin, d)

                        if dim is None:
                            # dim = w2v.vector_size
                            dim = len(d[list(d.keys())[0]])

                            # for w in w2v.vocab:
                            #     d[w] = w2v[w]

                    embedding_weights = utils.get_embedding_weights(
                        nb_words, tokenizer, d, dim)

                if dim is None:
                    dim = FLAGS.word_embedding_dim

                model_function = partial(
                    architecture,
                    nb_words=nb_words,
                    embedding_dim=dim,
                    static_embedding=FLAGS.static_word_embeddings == 0,
                    embedding_weights=embedding_weights,
                    seed=seed,
                    embedding_dropout=FLAGS.embedding_dropout,
                    ############################################################
                    img_shape=image_shape,
                    batch_norm_decay=FLAGS.batch_norm_decay,
                    emb_size=FLAGS.emb_size)
            else:
                model_function = partial(
                    architecture,
                    new_shape=new_shape,
                    img_shape=image_shape,
                    augmentation_function=augmentation_function,
                    batch_norm_decay=FLAGS.batch_norm_decay,
                    emb_size=FLAGS.emb_size)

            ti = tf.constant(test_images,
                             dtype=np.float32,
                             shape=list(test_images.shape),
                             name='test_in')
            tl = tf.constant(test_labels,
                             dtype=np.float32,
                             shape=list(test_labels.shape),
                             name='test_label')
            # Set up semisup model.
            model = semisup.SemisupModel(model_function,
                                         num_labels,
                                         image_shape,
                                         test_in=ti,
                                         test_label=tl)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # Need to add additional labels for virtual embeddings.
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb, seed=seed)

            # Add losses. Visit envelope settings of -1 fall back to the
            # corresponding walker settings.
            visit_weight_envelope_steps = (FLAGS.walker_weight_envelope_steps
                                           if FLAGS.visit_weight_envelope_steps
                                           == -1 else
                                           FLAGS.visit_weight_envelope_steps)
            visit_weight_envelope_delay = (FLAGS.walker_weight_envelope_delay
                                           if FLAGS.visit_weight_envelope_delay
                                           == -1 else
                                           FLAGS.visit_weight_envelope_delay)
            visit_weight = apply_envelope(
                type=FLAGS.visit_weight_envelope,
                step=model.step,
                final_weight=FLAGS.visit_weight,
                growing_steps=visit_weight_envelope_steps,
                delay=visit_weight_envelope_delay)
            walker_weight = apply_envelope(
                type=FLAGS.walker_weight_envelope,
                step=model.step,
                final_weight=FLAGS.walker_weight,
                growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                delay=FLAGS.walker_weight_envelope_delay)
            tf.summary.scalar('Weights_Visit', visit_weight)
            tf.summary.scalar('Weights_Walker', walker_weight)

            if FLAGS.unsup_samples != 0:
                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)
            model.add_accuracy(t_sup_logit, t_sup_labels)
            model.add_test_accuracy()

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(FLAGS.learning_rate,
                                           model.step,
                                           FLAGS.decay_steps,
                                           FLAGS.decay_factor,
                                           staircase=True),
                FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            optimizer = FLAGS.optimizer
            if optimizer == 'None':
                optimizer = None
            train_op = model.create_train_op(t_learning_rate,
                                             optimizer=optimizer)

            if FLAGS.num_cpus > 0:
                num_cpus = FLAGS.num_cpus
                config = tf.ConfigProto(intra_op_parallelism_threads=num_cpus,
                                        inter_op_parallelism_threads=num_cpus,
                                        allow_soft_placement=True,
                                        device_count={'CPU': num_cpus})
            else:
                config = tf.ConfigProto()

            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
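utils.load_word2vec and utils.get_embedding_weights are project-local helpers, not public APIs. A plausible sketch, assuming the textual word2vec format (one "word v1 v2 ..." line per entry) and a Keras-style tokenizer exposing word_index (both assumptions, since the helpers aren't shown):

import numpy as np

def load_word2vec(fin, d):
    """Sketch: read textual word2vec vectors from `fin` into dict `d`."""
    for line in fin:
        parts = line.rstrip().split(' ')
        if len(parts) < 3:
            continue  # skip the optional "<vocab_size> <dim>" header line
        d[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    return d

def get_embedding_weights(nb_words, tokenizer, d, dim):
    """Sketch: initial embedding matrix for the nb_words most frequent tokens."""
    # Random init for words without a pretrained vector.
    weights = np.random.uniform(-0.05, 0.05, (nb_words, dim)).astype(np.float32)
    for word, index in tokenizer.word_index.items():
        if index < nb_words and word in d:
            weights[index] = d[word]
    return weights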
Exemple #13
0
def main(_):
  train_images, train_labels = data.load_training_data(FLAGS.cifar)
  test_images, test_labels = data.load_test_data(FLAGS.cifar)

  print(train_images.shape, train_labels.shape)

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():

    def augment(image):
        # image_size = 28
        # image = tf.image.resize_image_with_crop_or_pad(
        #    image, image_size+4, image_size+4)
        # image = tf.random_crop(image, [image_size, image_size, 3])
        image = tf.image.random_flip_left_right(image)
        # Brightness/saturation/contrast provides small gains .2%~.5% on cifar.
        # image = tf.image.random_brightness(image, max_delta=63. / 255.)
        # image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
        return image

    model_func = getattr(semisup.architectures, FLAGS.model)
    model = semisup.SemisupModel(
        model_func, NUM_LABELS, IMAGE_SHAPE, emb_size=256,
        dropout_keep_prob=FLAGS.dropout_keep_prob,
        augmentation_function=augment)

    if FLAGS.random_batches:
      sup_lbls = np.asarray(
          np.hstack([np.ones(len(i)) * ind
                     for ind, i in enumerate(sup_by_label)]), np.int)
      sup_images = np.vstack(sup_by_label)
      batch_size = FLAGS.sup_per_batch * NUM_LABELS
      t_sup_images, t_sup_labels = semisup.create_input(sup_images, sup_lbls,
                                                        batch_size)
    else:
      t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
          sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    if FLAGS.semisup:
      if FLAGS.equal_cls_unsup:
        allimgs_bylabel = semisup.sample_by_label(train_images, train_labels,
                                                  int(NUM_TRAIN_IMAGES / NUM_LABELS), NUM_LABELS, seed)
        t_unsup_images, _ = semisup.create_per_class_inputs(
          allimgs_bylabel, FLAGS.sup_per_batch)
      else:
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)

      t_unsup_emb = model.image_to_embedding(t_unsup_images)
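      # The loss weights are placeholders so the Python training loop can
      # schedule them per step through feed_dict (see below).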
      proximity_weight = tf.placeholder(tf.float32, shape=[])
      visit_weight = tf.placeholder(tf.float32, shape=[])
      walker_weight = tf.placeholder(tf.float32, shape=[])

      model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels,
        walker_weight=walker_weight, visit_weight=visit_weight, proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
      FLAGS.learning_rate,
      model.step,
      FLAGS.decay_steps,
      FLAGS.decay_factor,
      staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()
  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      feed_dict = {}
      if FLAGS.semisup:
        if FLAGS.proximity_loss:
          # Route the visit weight to the proximity loss instead.
          feed_dict = {
              walker_weight: FLAGS.walker_weight,
              visit_weight: 0.0,
              proximity_weight: FLAGS.visit_weight,
              # Possible ramped schedules, currently disabled:
              # walker_weight: 0.1 + apply_envelope(
              #     'lin', step, FLAGS.walker_weight, 500, FLAGS.decay_steps),
              # proximity_weight: 0.1 + apply_envelope(
              #     'lin', step, FLAGS.visit_weight, 500, FLAGS.decay_steps),
              # t_logit_weight: FLAGS.logit_weight,
              # t_learning_rate: 5e-5 + apply_envelope(
              #     'log', step, FLAGS.learning_rate, FLAGS.warmup_steps, 0),
          }
        else:
          feed_dict = {
              walker_weight: FLAGS.walker_weight,
              visit_weight: FLAGS.visit_weight,
              proximity_weight: 0.0,
              # t_logit_weight: FLAGS.logit_weight,
              # t_learning_rate: 5e-5 + apply_envelope(
              #     'log', step, FLAGS.learning_rate, FLAGS.warmup_steps, 0),
          }

      _, summaries, train_loss = sess.run(
          [train_op, summary_op, model.train_loss], feed_dict)

      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images, sess).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        #saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
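For completeness, semisup.confusion_matrix as used in the evaluation loops above can be reproduced in a few lines of NumPy; a sketch assuming integer labels in [0, num_labels):

import numpy as np

def confusion_matrix(labels, predictions, num_labels):
    """Sketch: rows are true classes, columns are predicted classes."""
    mtx = np.zeros((num_labels, num_labels), dtype=np.int64)
    for truth, pred in zip(labels, predictions):
        mtx[int(truth), int(pred)] += 1
    return mtx

The test error printed in the loops is then just (test_labels != test_pred).mean() * 100, i.e. the off-diagonal mass of this matrix.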