def main(_):
    # Load dataset
    tf.app.flags.FLAGS.data_dir = '/work/haeusser/data/imagenet/shards'
    dataset = ImagenetData(subset='validation')
    assert dataset.data_files()

    num_labels = dataset.num_classes() + 1
    image_shape = [FLAGS.image_size, FLAGS.image_size, 3]

    graph = tf.Graph()
    with graph.as_default():

        images, labels = image_processing.batch_inputs(
            dataset,
            32,
            train=True,
            num_preprocess_threads=16,
            num_readers=FLAGS.num_readers)
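        # Note: the hard-coded batch size of 32 above should agree with
        # FLAGS.eval_batch_size, which is used further down to compute
        # num_batches; otherwise the loop evaluates the wrong number of examples.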

        # Set up semisup model.
        model = semisup.SemisupModel(semisup.architectures.inception_model,
                                     num_labels,
                                     image_shape,
                                     test_in=images)

        # Add moving average variables.
        for var in tf.get_collection('moving_vars'):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        for var in slim.get_model_variables():
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

        # Get prediction tensor from semisup model.
        predictions = tf.argmax(model.test_logit, 1)

        # Accuracy metric for summaries.
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
        })
        for name, value in names_to_values.items():
            tf.summary.scalar(name, value)

        # Run the actual evaluation loop.
        num_batches = math.ceil(dataset.num_examples_per_epoch() /
                                float(FLAGS.eval_batch_size))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.logdir,
            logdir=FLAGS.logdir,
            num_evals=num_batches,
            eval_op=names_to_updates.values(),
            eval_interval_secs=FLAGS.eval_interval_secs,
            session_config=config)
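
# These app-style scripts are normally launched through tf.app.run(), which
# parses the flags and then calls main(); the entry point is assumed here and
# is not part of the snippet:
#
#     if __name__ == '__main__':
#         tf.app.run()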
Example #2
def main(_):
    # Get dataset-related toolbox.
    dataset_tools = import_module('semisup.tools.' + FLAGS.dataset)
    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE
    test_images, test_labels = dataset_tools.get_data('test')
    print(test_images)

    graph = tf.Graph()
    with graph.as_default():

        # Set up input pipeline.
        #image, label = tf.train.slice_input_producer([test_images, test_labels])
        #images, labels = tf.train.batch(
        #    [image, label], batch_size=FLAGS.eval_batch_size)
        images, labels = semisup.create_input(test_images, test_labels,
                                              FLAGS.eval_batch_size)

        images = tf.cast(images, tf.float32)
        labels = tf.cast(labels, tf.int64)

        # Reshape if necessary.
        if FLAGS.new_size > 0:
            new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
        else:
            new_shape = None

        if FLAGS.augmentation:
            # TODO(haeusser) generalize augmentation
            def _random_invert(inputs1, _):
                inputs = tf.cast(inputs1, tf.float32)
                inputs = tf.image.adjust_brightness(
                    inputs, tf.random_uniform((1, 1), 0.0, 0.5))
                inputs = tf.image.random_contrast(inputs, 0.3, 1)
                # inputs = tf.image.per_image_standardization(inputs)
                inputs = tf.image.random_hue(inputs, 0.05)
                inputs = tf.image.random_saturation(inputs, 0.5, 1.1)

                def f1(): return tf.abs(inputs)  # annotations

                def f2(): return tf.abs(inputs1)

                return tf.cond(tf.less(tf.random_uniform([], 0.0, 1), 0.5), f1, f2)

            augmentation_function = _random_invert
        else:
            augmentation_function = None
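        # When enabled, the augmentation callable is passed into the architecture
        # below; it jitters brightness/contrast/hue/saturation and, for roughly
        # half of the batches, returns the jittered images instead of the originals.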

        # Create function that defines the network.
        model_function = partial(
            architecture,
            is_training=False,
            new_shape=new_shape,
            img_shape=image_shape,
            augmentation_function=augmentation_function,
            image_summary=False,
            emb_size=FLAGS.emb_size)


        # Set up semisup model.
        model = semisup.SemisupModel(
            model_function,
            num_labels,
            image_shape,
            test_in=images)

        # Add moving average variables.
        for var in tf.get_collection('moving_vars'):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        for var in slim.get_model_variables():
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

        # Get prediction tensor from semisup model.
        predictions = tf.argmax(model.test_logit, 1)
        cmatrix = tf.confusion_matrix(labels, predictions, num_labels)
        # Accuracy metric for summaries.
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        })
        for name, value in names_to_values.items():
            tf.summary.scalar(name, value)
        confusion_image = tf.reshape(tf.cast(cmatrix, tf.float32),
                                     [1, 10, 10, 1])
        tf.summary.tensor_summary('cmatrix', cmatrix)
        tf.summary.image('confusion matrix', confusion_image)
        tf_heatmap = tfplot.wrap_axesplot(sns.heatmap,
                                          figsize=(7, 5),
                                          cbar=False,
                                          annot=True,
                                          yticklabels=dataset_tools.CLASSES,
                                          cmap='jet')
        tf.summary.image("heat_maps",
                         tf.reshape(tf_heatmap(cmatrix), [1, 500, 700, 4]))
        # Run the actual evaluation loop.
        num_batches = math.ceil(len(test_labels) / float(FLAGS.eval_batch_size))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.logdir + '/train',
            logdir=FLAGS.logdir + '/eval',
            num_evals=num_batches,
            eval_op=tf.Print(list(names_to_updates.values()),
                             [confusion_image],
                             message="cmatrix:",
                             summarize=500),
            eval_interval_secs=FLAGS.eval_interval_secs,
            session_config=config,
            timeout=FLAGS.timeout,
            #hooks=[tf_debug.LocalCLIDebugHook(ui_type="readline")]
        )
Example #3
def main(_):
    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, _ = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Requested more unlabeled samples ({}) than are available in the '
            'unlabeled set ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]
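        # Subsample the unlabeled pool without replacement so each unlabeled
        # image is used at most once.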

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) revert this to the general case
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
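                # _random_invert flips pixel values (255 - x) for roughly half of
                # the images in each supervised batch and leaves the rest unchanged.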

                # if hasattr(dataset_tools, 'augmentation_params'):
                #    augmentation_function = partial(
                #        apply_augmentation, params=dataset_tools.augmentation_params)
                # else:
                #    augmentation_function = apply_affine_augmentation
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # Need to add additional labels for the virtual embeddings.
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.mmd:
                sys.path.insert(0, '/usr/wiss/haeusser/libs/opt-mmd/gan')
                from mmd import mix_rbf_mmd2

                bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]  # original

                t_sup_flat = tf.reshape(t_sup_emb,
                                        [FLAGS.sup_per_batch * num_labels, -1])
                t_unsup_flat = tf.reshape(t_unsup_emb,
                                          [FLAGS.unsup_batch_size, -1])
                mmd_loss = mix_rbf_mmd2(t_sup_flat,
                                        t_unsup_flat,
                                        sigmas=bandwidths)
                tf.losses.add_loss(mmd_loss)
                tf.summary.scalar('MMD_loss', mmd_loss)
            else:
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1 else
                    FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1 else
                    FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)
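                # apply_envelope (defined elsewhere in this file) presumably ramps
                # each weight from 0 toward final_weight over `growing_steps`
                # training steps after an initial `delay`, e.g. roughly
                #   weight = final_weight * clip((step - delay) / growing_steps, 0, 1)
                # for the linear envelope.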

                if FLAGS.unsup_samples != 0:
                    model.add_semisup_loss(t_sup_emb,
                                           t_unsup_emb,
                                           t_sup_labels,
                                           visit_weight=visit_weight,
                                           walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate schedule if necessary.
            if FLAGS.custom_lr_vals is not None and FLAGS.custom_lr_steps is not None:
                boundaries = [
                    tf.convert_to_tensor(x, tf.int64)
                    for x in FLAGS.custom_lr_steps
                ]

                t_learning_rate = piecewise_constant(model.step, boundaries,
                                                     FLAGS.custom_lr_vals)
            else:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)
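                # Staircase exponential decay, floored so the learning rate never
                # drops below FLAGS.minimum_learning_rate.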

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            ''' BEGIN EVIL STUFF !!!
            checkpoint_path = '/usr/wiss/haeusser/experiments/inception_v4/model.ckpt'
            mapping = dict()
            for x in slim.get_model_variables():
                name = x.name[:-2]
                ok = True
                for banned in ['Logits', 'fc1', 'fully_connected', 'ExponentialMovingAverage']:
                    if banned in name:
                        ok = False
                if ok:
                    mapping[name[4:]] = x

            init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                checkpoint_path, mapping)

            # Create an initial assignment function.
            def InitAssignFn(sess):
                sess.run(init_assign_op, init_feed_dict)
                print("#################################### Checkpoint loaded.")
             '''  # END EVIL STUFF !!!

            slim.learning.train(
                train_op,
                # init_fn=InitAssignFn,  # EVIL, too
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
Example #4
def main(_):
  FLAGS.emb_size = 128
  optimizer = 'adam'

  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)

  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():
    model_func = semisup.architectures.mnist_model
    if FLAGS.dropout_keep_prob < 1:
        model_func = semisup.architectures.mnist_model_dropout

    model = semisup.SemisupModel(model_func, NUM_LABELS,
                                 IMAGE_SHAPE, optimizer=optimizer, emb_size=FLAGS.emb_size,
                                 dropout_keep_prob=FLAGS.dropout_keep_prob)

    # Set up inputs.
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
      sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    t_unsup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)

    proximity_weight = tf.placeholder("float", shape=[])
    visit_weight = tf.placeholder("float", shape=[])
    walker_weight = tf.placeholder("float", shape=[])
    t_logit_weight = tf.placeholder("float", shape=[])
    t_l1_weight = tf.placeholder("float", shape=[])
    t_learning_rate = tf.placeholder("float", shape=[])

    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    model.add_semisup_loss(
      t_sup_emb, t_unsup_emb, t_sup_labels,
      walker_weight=walker_weight, visit_weight=visit_weight, proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels, weight=t_logit_weight)

    model.add_emb_regularization(t_sup_emb, weight=t_l1_weight)
    model.add_emb_regularization(t_unsup_emb, weight=t_l1_weight)

    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  sess = tf.InteractiveSession(graph=graph)

  unsup_images_iterator = semisup.create_input(train_images, train_labels,
                                               FLAGS.unsup_batch_size)
  tf.global_variables_initializer().run()

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
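  # create_input presumably uses a queue-based pipeline, so the queue runners
  # must be started before sess.run() on unsup_images_iterator can return data.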

  use_new_visit_loss = False
  for step in range(FLAGS.max_steps):
    unsup_images, _ = sess.run(unsup_images_iterator)

    if use_new_visit_loss:
      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], {
        t_unsup_images: unsup_images,
        walker_weight: FLAGS.walker_weight,
        proximity_weight: 0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0)
                          - apply_envelope("lin", step, FLAGS.visit_weight, 2000, FLAGS.warmup_steps),
        t_logit_weight: FLAGS.logit_weight,
        t_l1_weight: FLAGS.l1_weight,
        visit_weight: apply_envelope("lin", step, FLAGS.visit_weight, 2000, FLAGS.warmup_steps),
        t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
      })
    else:
      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], {
        t_unsup_images: unsup_images,
        walker_weight: FLAGS.walker_weight,
        visit_weight: 0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0),
        proximity_weight: 0,
        t_logit_weight: FLAGS.logit_weight,
        t_l1_weight: FLAGS.l1_weight,
        t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
      })

    if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
      print('Step: %d' % step)
      test_pred = model.classify(test_images).argmax(-1)
      conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
      test_err = (test_labels != test_pred).mean() * 100
      print(conf_mtx)
      print('Test error: %.2f %%' % test_err)
      print('Train loss: %.2f ' % train_loss)
      print()

      test_summary = tf.Summary(
        value=[tf.Summary.Value(
          tag='Test Err', simple_value=test_err)])

      summary_writer.add_summary(summaries, step)
      summary_writer.add_summary(test_summary, step)
Example #5
def init_graph(sup_by_label,
               train_images,
               train_labels,
               test_images,
               test_labels,
               logdir,
               unsup_batch_size=None):
    if unsup_batch_size is None:
        unsup_batch_size = FLAGS.unsup_batch_size

    graph = tf.Graph()
    with graph.as_default():
        model_func = semisup.architectures.mnist_model
        if FLAGS.dropout_keep_prob < 1:
            model_func = semisup.architectures.mnist_model_dropout

        model = semisup.SemisupModel(model_func,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up inputs.
        t_sup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)
        t_sup_labels = tf.placeholder(dtype=tf.int32, shape=[
            None,
        ])

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        t_unsup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)

        proximity_weight = tf.placeholder("float", shape=[])
        visit_weight = tf.placeholder("float", shape=[])
        walker_weight = tf.placeholder("float", shape=[])
        t_logit_weight = tf.placeholder("float", shape=[])
        t_l1_weight = tf.placeholder("float", shape=[])
        t_learning_rate = tf.placeholder("float", shape=[])

        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        model.add_semisup_loss(t_sup_emb,
                               t_unsup_emb,
                               t_sup_labels,
                               walker_weight=walker_weight,
                               visit_weight=visit_weight,
                               proximity_weight=proximity_weight,
                               normalize_along_classes=True)
        model.add_logit_loss(t_sup_logit, t_sup_labels, weight=t_logit_weight)

        model.add_emb_regularization(t_sup_emb, weight=t_l1_weight)
        model.add_emb_regularization(t_unsup_emb, weight=t_l1_weight)

        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(logdir, graph)

        saver = tf.train.Saver()
        unsup_images_iterator = semisup.create_input(train_images,
                                                     train_labels,
                                                     unsup_batch_size)

        sess = tf.Session(graph=graph)
        sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    def init_iterators():
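        # Build the per-class supervised input queues in their own graph and
        # session so they can be torn down and rebuilt independently of the main
        # training graph (presumably after add_sample() below grows sup_by_label).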
        input_graph = tf.Graph()
        with input_graph.as_default():
            sup_images_iterator, sup_labels_iterator = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            input_sess = tf.Session(graph=input_graph)
            input_sess.run(tf.global_variables_initializer())

        input_coord = tf.train.Coordinator()
        input_threads = tf.train.start_queue_runners(sess=input_sess,
                                                     coord=input_coord)

        return sup_images_iterator, sup_labels_iterator, input_sess, input_threads, input_coord

    def train_warmup():
        """ if there are few samples, warmup at first"""
        sup_images_iterator, sup_labels_iterator, input_sess, input_threads, input_coord = init_iterators(
        )

        model.reset_optimizer(sess)

        for step in range(FLAGS.max_steps):
            unsup_images, _ = sess.run(unsup_images_iterator)
            si, sl = input_sess.run([sup_images_iterator, sup_labels_iterator])

            _, summaries, train_loss = sess.run(
                [train_op, summary_op, model.train_loss], {
                    t_unsup_images: unsup_images,
                    walker_weight: FLAGS.walker_weight,
                    proximity_weight: 0,
                    visit_weight:
                        0.1 + apply_envelope("lin", step, 0.4,
                                             FLAGS.warmup_steps, 0),
                    t_logit_weight: 0.5,  # FLAGS.logit_weight,
                    t_l1_weight: FLAGS.l1_weight,
                    t_sup_images: si,
                    t_sup_labels: sl,
                    t_learning_rate:
                        5e-5 + apply_envelope("log", step, FLAGS.learning_rate,
                                              FLAGS.warmup_steps, 0),
                })

            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print('Train loss: %.2f ' % train_loss)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()

    def train_finetune(lr=0.001, steps=FLAGS.max_steps):
        sup_images_iterator, sup_labels_iterator, input_sess, input_threads, input_coord = init_iterators(
        )

        model.reset_optimizer(sess)
        for step in range(int(steps)):
            unsup_images, _ = sess.run(unsup_images_iterator)
            si, sl = input_sess.run([sup_images_iterator, sup_labels_iterator])
            _, summaries, train_loss = sess.run(
                [train_op, summary_op, model.train_loss], {
                    t_unsup_images: unsup_images,
                    walker_weight: 1,
                    proximity_weight: 0,
                    visit_weight: 1,
                    t_l1_weight: 0,
                    t_logit_weight: 1,
                    t_learning_rate: lr,
                    t_sup_images: si,
                    t_sup_labels: sl
                })
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print('Train loss: %.2f ' % train_loss)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()

    def choose_sample(method="propose_samples"):
        sup_images_iterator, sup_labels_iterator, input_sess, input_threads, input_coord = init_iterators(
        )

        sup_lbls = np.hstack(
            [np.ones(len(i)) * ind for ind, i in enumerate(sup_by_label)])
        sup_images = np.vstack(sup_by_label)

        print("total number of supervised samples", sup_lbls.shape)

        n_samples = FLAGS.num_new_samples_per_run
        propose_func = getattr(model, method)
        inds_lba = propose_func(sup_images,
                                sup_lbls,
                                train_images,
                                train_labels,
                                sess,
                                n_samples=n_samples)
        print('sampled new images with labels', train_labels[inds_lba])

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()

        return inds_lba

    def add_sample(index):
        # Add the chosen training sample to its class bucket; the next call to
        # init_iterators() will include it in the supervised queues.
        label = train_labels[index]
        img = train_images[index]
        sup_by_label[label] = np.vstack([sup_by_label[label], [img]])

    return train_warmup, train_finetune, choose_sample, add_sample
Example #6
def main(argv):
    del argv
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    np.random.seed(seed)

    # Load data.
    if FLAGS.dataset.startswith('gs'):
        if FLAGS.target_dataset is not None:
            raise NotImplementedError(
                'target dataset is not supported for GS filesystem')
        train_images, train_labels = gs_data.get_data('train', FLAGS.dataset)
        train_images_unlabeled, _ = gs_data.get_data('unlabeled',
                                                     FLAGS.dataset)
        test_images, test_labels = gs_data.get_data('test', FLAGS.dataset)
    else:
        dataset_tools = import_module('tools.' + FLAGS.dataset)
        train_images, train_labels = dataset_tools.get_data('train')
        if FLAGS.target_dataset is not None:
            target_dataset_tools = import_module('tools.' +
                                                 FLAGS.target_dataset)
            train_images_unlabeled, _ = target_dataset_tools.get_data(
                FLAGS.target_dataset_split)
        else:
            train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

        test_images, test_labels = dataset_tools.get_data('test')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    tokenizer = None
    encoder = None
    nb_words = None
    seq_len = None
    if architecture == semisup.architectures.cnn_sentiment_model:
        if FLAGS.dataset.startswith('gs'):
            dataset_tools = gs_data
            with BytesIO(
                    file_io.read_file_to_string('{}/tokenizer.pkl'.format(
                        FLAGS.dataset))) as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)
        else:
            with open(FLAGS.dictionaries, 'rb') as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)

        dataset_tools.IMAGE_SHAPE = [seq_len]
        dataset_tools.NUM_LABELS = len(encoder.classes_)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Requested more unlabeled samples ({}) than are available in the '
            'unlabeled set ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            tf.set_random_seed(seed)

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled,
                                                  None,
                                                  FLAGS.unsup_batch_size,
                                                  seed=seed,
                                                  shuffle=FLAGS.shuffle_input)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label,
                FLAGS.sup_per_batch,
                seed=seed,
                shuffle=FLAGS.shuffle_input)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            if architecture == semisup.architectures.cnn_sentiment_model:

                embedding_weights = None
                dim = None
                if FLAGS.w2v is not None:
                    d = dict()
                    for path in FLAGS.w2v.split(','):
                        tf.logging.info('Loading word2vec: {}'.format(path))
                        if path.startswith('gs'):
                            s = file_io.read_file_to_string(path)
                            try:
                                s = unicode(s, errors='replace')
                            except:
                                pass

                            with StringIO(s) as fin:
                                # w2v = KeyedVectors.load_word2vec_format(fin, unicode_errors='ignore')
                                d = utils.load_word2vec(fin, d)
                        else:
                            # w2v = KeyedVectors.load_word2vec_format(path, unicode_errors='ignore')
                            with open(path, 'r') as fin:
                                d = utils.load_word2vec(fin, d)

                        if dim is None:
                            # dim = w2v.vector_size
                            dim = len(d[list(d.keys())[0]])

                            # for w in w2v.vocab:
                            #     d[w] = w2v[w]

                    embedding_weights = utils.get_embedding_weights(
                        nb_words, tokenizer, d, dim)

                if dim is None:
                    dim = FLAGS.word_embedding_dim

                model_function = partial(
                    architecture,
                    nb_words=nb_words,
                    embedding_dim=dim,
                    static_embedding=FLAGS.static_word_embeddings == 0,
                    embedding_weights=embedding_weights,
                    seed=seed,
                    embedding_dropout=FLAGS.embedding_dropout,
                    ############################################################
                    img_shape=image_shape,
                    batch_norm_decay=FLAGS.batch_norm_decay,
                    emb_size=FLAGS.emb_size)
            else:
                model_function = partial(
                    architecture,
                    new_shape=new_shape,
                    img_shape=image_shape,
                    augmentation_function=augmentation_function,
                    batch_norm_decay=FLAGS.batch_norm_decay,
                    emb_size=FLAGS.emb_size)

            ti = tf.constant(test_images,
                             dtype=np.float32,
                             shape=list(test_images.shape),
                             name='test_in')
            tl = tf.constant(test_labels,
                             dtype=np.float32,
                             shape=list(test_labels.shape),
                             name='test_label')
            # Set up semisup model.
            model = semisup.SemisupModel(model_function,
                                         num_labels,
                                         image_shape,
                                         test_in=ti,
                                         test_label=tl)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # Need to add additional labels for the virtual embeddings.
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb, seed=seed)

            # Add losses.
            visit_weight_envelope_steps = (FLAGS.walker_weight_envelope_steps
                                           if FLAGS.visit_weight_envelope_steps
                                           == -1 else
                                           FLAGS.visit_weight_envelope_steps)
            visit_weight_envelope_delay = (FLAGS.walker_weight_envelope_delay
                                           if FLAGS.visit_weight_envelope_delay
                                           == -1 else
                                           FLAGS.visit_weight_envelope_delay)
            visit_weight = apply_envelope(
                type=FLAGS.visit_weight_envelope,
                step=model.step,
                final_weight=FLAGS.visit_weight,
                growing_steps=visit_weight_envelope_steps,
                delay=visit_weight_envelope_delay)
            walker_weight = apply_envelope(
                type=FLAGS.walker_weight_envelope,
                step=model.step,
                final_weight=FLAGS.walker_weight,
                growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                delay=FLAGS.walker_weight_envelope_delay)
            tf.summary.scalar('Weights_Visit', visit_weight)
            tf.summary.scalar('Weights_Walker', walker_weight)

            if FLAGS.unsup_samples != 0:
                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)
            model.add_accuracy(t_sup_logit, t_sup_labels)
            model.add_test_accuracy()

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(FLAGS.learning_rate,
                                           model.step,
                                           FLAGS.decay_steps,
                                           FLAGS.decay_factor,
                                           staircase=True),
                FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            optimizer = FLAGS.optimizer
            if optimizer == 'None':
                optimizer = None
            train_op = model.create_train_op(t_learning_rate,
                                             optimizer=optimizer)

            if FLAGS.num_cpus > 0:
                num_cpus = FLAGS.num_cpus
                config = tf.ConfigProto(intra_op_parallelism_threads=num_cpus,
                                        inter_op_parallelism_threads=num_cpus,
                                        allow_soft_placement=True,
                                        device_count={'CPU': num_cpus})
            else:
                config = tf.ConfigProto()

            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
Example #7
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS, IMAGE_SHAPE)

        # Set up inputs.
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
            sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        model.add_semisup_loss(t_sup_emb,
                               t_unsup_emb,
                               t_sup_labels,
                               visit_weight=FLAGS.visit_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
Example #8
def train_test(FLAGS):

    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, train_images_label = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, train_images_label = dataset_tools.get_data(
            'unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Requested more unlabeled samples ({}) than are available in the '
            'unlabeled set ({}).'.format(FLAGS.unsup_samples, num_unlabeled))
        # TODO: make sample selection per class (done)
        #unsup_by_label = semisup.sample_by_label(train_images_unlabeled, train_images_label,
        #                                       FLAGS.unsup_samples/num_labels+num_labels, num_labels,
        #                                       seed)

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            #print(t_sup_images.shape)
            #with tf.Session() as sess: print (t_sup_images.eval().shape)
            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs1, _):
                    inputs = tf.cast(inputs1, tf.float32)
                    inputs = tf.image.adjust_brightness(
                        inputs, tf.random_uniform((1, 1), 0.0, 0.5))
                    inputs = tf.image.random_contrast(inputs, 0.3, 1)
                    # inputs = tf.image.per_image_standardization(inputs)
                    inputs = tf.image.random_hue(inputs, 0.05)
                    inputs = tf.image.random_saturation(inputs, 0.5, 1.1)

                    def f1():
                        return tf.abs(inputs)  #annotations

                    def f2():
                        return tf.abs(inputs1)

                    return tf.cond(tf.less(tf.random_uniform([], 0.0, 1), 0.5),
                                   f1, f2)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.unsup_samples != 0:
                t_unsup_emb = model.image_to_embedding(t_unsup_images)
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1 else
                    FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1 else
                    FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)

                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate
            if FLAGS.learning_rate_type is None:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)
            elif FLAGS.learning_rate_type == 'exp2':
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.maximum_learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='exp_range',
                                         gamma=0.9999),
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994))

            else:
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994),
                    FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            final_loss = slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
                #session_wrapper=tf_debug.LocalCLIDebugWrapperSession
            )

            print(final_loss)
    return final_loss
def main(_):
    if FLAGS.logdir is not None:
        # FLAGS.logdir = FLAGS.logdir + '/t_' + datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
        _unsup_batch_size = FLAGS.unsup_batch_size if FLAGS.semisup else 0
        FLAGS.logdir = "{0}/i{1}_e{2}_s{3}_un{4}_d{5}_{6}_w{7}".format(
            FLAGS.logdir, image_shape[0], FLAGS.emb_size,
            FLAGS.sup_per_class, _unsup_batch_size, FLAGS.decay_steps,
            int(FLAGS.decay_factor * 100), FLAGS.warmup_steps)
        try:
            shutil.rmtree(FLAGS.logdir)
        except OSError as e:
            print("Error: %s - %s." % (e.filename, e.strerror))

    # Load image data from npy file
    train_images, test_images, train_labels, test_labels = dataset_tools.get_data(
        one_hot=False, test_size=FLAGS.test_size, image_shape=image_shape)

    unique, counts = np.unique(test_labels, return_counts=True)
    testset_distribution = dict(zip(unique, counts))

    train_X, train_Y = semisup.sample_by_label_v2(train_images, train_labels,
                                                  FLAGS.sup_per_class,
                                                  NUM_LABELS,
                                                  np.random.randint(0, 100))

    def aug(image, label):
        return apply_augmentation(
            image,
            target_shape=image_shape,
            params=dataset_tools.augmentation_params), label

    def aug_unsup(image):
        return apply_augmentation(image,
                                  target_shape=image_shape,
                                  params=dataset_tools.augmentation_params)

    graph = tf.Graph()
    with graph.as_default():
        # Create function that defines the network.
        architecture = getattr(semisup.architectures, FLAGS.architecture)
        model_function = partial(architecture,
                                 new_shape=None,
                                 img_shape=IMAGE_SHAPE,
                                 batch_norm_decay=FLAGS.batch_norm_decay,
                                 emb_size=FLAGS.emb_size)

        model = semisup.SemisupModel(model_function,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     emb_size=FLAGS.emb_size,
                                     dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up supervised inputs.
        t_images = tf.placeholder("float", shape=[None] + image_shape)
        t_labels = tf.placeholder(train_labels.dtype, shape=[None])
        dataset = Dataset.from_tensor_slices((t_images, t_labels))
        # Apply augmentation
        if FLAGS.augmentation:
            dataset = dataset.map(aug)
        dataset = dataset.shuffle(buffer_size=FLAGS.sup_per_class * NUM_LABELS)
        dataset = dataset.repeat().batch(FLAGS.sup_per_class * NUM_LABELS)
        iterator = dataset.make_initializable_iterator()
        t_sup_images, t_sup_labels = iterator.get_next()
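        # Note: the batch size equals sup_per_class * NUM_LABELS, so every
        # labeled example appears (shuffled) in each supervised batch.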

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            unsup_t_images = tf.placeholder("float",
                                            shape=[None] + image_shape)
            unsup_dataset = Dataset.from_tensor_slices(unsup_t_images)
            # Apply augmentation
            if FLAGS.augmentation:
                unsup_dataset = unsup_dataset.map(aug_unsup)
            unsup_dataset = unsup_dataset.shuffle(
                buffer_size=FLAGS.unsup_batch_size)
            unsup_dataset = unsup_dataset.repeat().batch(
                FLAGS.unsup_batch_size)
            unsup_iterator = unsup_dataset.make_initializable_iterator()
            t_unsup_images = unsup_iterator.get_next()

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(t_sup_emb,
                                   t_unsup_emb,
                                   t_sup_labels,
                                   walker_weight=FLAGS.walker_weight,
                                   visit_weight=FLAGS.visit_weight)

            #model.add_emb_regularization(t_unsup_emb, weight=FLAGS.l1_weight)
        else:
            model.loss_aba = tf.constant(0)
            model.visit_loss = tf.constant(0)

        t_logit_loss = model.add_logit_loss(t_sup_logit,
                                            t_sup_labels,
                                            weight=FLAGS.logit_weight)
        # t_logit_loss = tf.constant(0)

        #model.add_emb_regularization(t_sup_emb, weight=FLAGS.l1_weight)

        t_learning_rate = tf.placeholder("float", shape=[])

        train_op = model.create_train_op(t_learning_rate)

        summary_op = tf.summary.merge_all()

        if FLAGS.logdir is not None:
            summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)
            saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:

        sess.run(iterator.initializer,
                 feed_dict={
                     t_images: train_X,
                     t_labels: train_Y
                 })
        if FLAGS.semisup:
            sess.run(unsup_iterator.initializer,
                     feed_dict={unsup_t_images: train_images})

        tf.global_variables_initializer().run()

        if FLAGS.restore_checkpoint is not None:
            variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope='net')
            restorer = tf.train.Saver(var_list=variables)
            restorer.restore(sess, FLAGS.restore_checkpoint)

        learning_rate_ = FLAGS.learning_rate

        for step in range(FLAGS.max_steps):
            lr = learning_rate_
            if step < FLAGS.warmup_steps:
                lr = 1e-6 + semisup.apply_envelope(
                    "log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)

            step_start_time = time.time()
            _, summaries, train_loss, aba_loss, visit_loss, logit_loss = sess.run(
                [
                    train_op, summary_op, model.train_loss, model.loss_aba,
                    model.visit_loss, t_logit_loss
                ], {t_learning_rate: lr})

            if not FLAGS.run_in_background:
                sys.stderr.write("\rstep: %d, Step time: %.4f sec" %
                                 (step, (time.time() - step_start_time)))
                sys.stderr.flush()
            # sup_images = sess.run(d_sup_image

            if (step + 1) % FLAGS.eval_interval == 0 or step == 99 or step == 0:
                print('\n=======================')
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Target:', testset_distribution)
                print('Test error: %.2f%%' % test_err)
                print('Learning rate:', lr)
                print('train_loss:', train_loss)
                if FLAGS.semisup:
                    print('walker_loss_aba:', aba_loss)
                    print('visit_loss:', visit_loss)
                print('logit_loss:', logit_loss)
                print('Image shape:', IMAGE_SHAPE)
                print('emb_size: ', FLAGS.emb_size)
                print('sup_per_class: ', FLAGS.sup_per_class)
                if FLAGS.semisup:
                    print('unsup_batch_size: ', FLAGS.unsup_batch_size)
                print('semisup: ', FLAGS.semisup)
                print('augmentation: ', FLAGS.augmentation)
                print('decay_steps: ', FLAGS.decay_steps)
                print('decay_factor: ', FLAGS.decay_factor)
                print('warmup_steps: ', FLAGS.warmup_steps)
                if FLAGS.semisup:
                    print('walker_weight: ', FLAGS.walker_weight)
                    print('visit_weight: ', FLAGS.visit_weight)
                print('logit_weight: ', FLAGS.logit_weight)
                print('=======================\n')

                if FLAGS.logdir is not None:
                    sum_values = {'Test error': test_err}
                    summary_writer.add_summary(summaries, step)
                    for key, value in sum_values.items():
                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag=key, simple_value=value)
                        ])
                        summary_writer.add_summary(summary, step)

            if FLAGS.logdir is not None and (step + 1) % 5000 == 0:
                path = saver.save(sess, FLAGS.logdir + '/checkpoint',
                                  model.step)
                print('@@model_path:%s' % path)

            if step % FLAGS.decay_steps == 0 and step > 0:
                learning_rate_ = learning_rate_ * FLAGS.decay_factor
def main(_):
  unsup_multiplier = NUM_TRAIN_IMAGES / NUM_LABELS / FLAGS.sup_per_class
  print(unsup_multiplier)
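  # Worked example with hypothetical values: NUM_TRAIN_IMAGES=50000,
  # NUM_LABELS=10, sup_per_class=100 gives unsup_multiplier = 50000/10/100 = 50.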

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None

  graph = tf.Graph()
  with graph.as_default():
    model = semisup.SemisupModel(semisup.architectures.densenet_model, NUM_LABELS,
                                 IMAGE_SHAPE)

    # Set up inputs.
    train_sup, train_labels_sup = data.build_input(FLAGS.cifar,
                                                   os.path.join(FLAGS.dataset_dir, TRAIN_FILE),
                                                   batch_size=FLAGS.sup_batch_size,
                                                   mode='train',
                                                   subset_factor=unsup_multiplier)

    if FLAGS.unsup_batch_size > 0:
      train_unsup, train_labels_unsup = data.build_input(FLAGS.cifar,
                                                       os.path.join(FLAGS.dataset_dir, TRAIN_FILE),
                                                       batch_size=FLAGS.unsup_batch_size,
                                                       mode='train')

    test_images, test_labels = data.build_input(FLAGS.cifar,
                                                os.path.join(FLAGS.dataset_dir, TEST_FILE),
                                                batch_size=TEST_BATCH_SIZE,
                                                mode='test')

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(train_sup)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    if FLAGS.unsup_batch_size > 0:
      t_unsup_emb = model.image_to_embedding(train_unsup)
      model.add_semisup_loss(
            t_sup_emb, t_unsup_emb, train_labels_sup, visit_weight=FLAGS.visit_weight)

    model.add_logit_loss(t_sup_logit, train_labels_sup)

    t_learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        model.step,
        FLAGS.decay_epochs * steps_per_epoch,
        FLAGS.decay_factor,
        staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    last_epoch = -1

    for step in range(FLAGS.max_epochs * int(steps_per_epoch)):
      _, summaries, tl = sess.run([train_op, summary_op, model.train_loss])

      epoch = math.floor(step / steps_per_epoch)
      if (epoch >= 0 and epoch % FLAGS.eval_interval == 0) or epoch == 1:
        if epoch == last_epoch: #don't log twice for same epoch
          continue

        last_epoch = epoch
        num_total_batches = 10000 // TEST_BATCH_SIZE  # assumes the 10000-image CIFAR test set
        print('Epoch: %d' % epoch)

        t_imgs, t_lbls = model.get_images(test_images, test_labels, num_total_batches, sess)
        test_pred = model.classify(t_imgs).argmax(-1)
        conf_mtx = semisup.confusion_matrix(t_lbls, test_pred, NUM_LABELS)
        test_err = (t_lbls != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        t_imgs, t_lbls = model.get_images(train_sup, train_labels_sup, num_total_batches, sess)
        train_pred = model.classify(t_imgs).argmax(-1)
        train_err = (t_lbls != train_pred).mean() * 100
        print('Train error: %.2f %%' % train_err)

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
Example No. 11
def main(_):
    architecture = getattr(semisup.architectures, FLAGS.architecture)

    if FLAGS.dataset.startswith('gs'):
        # Read as bytes so np.load receives a binary stream.
        data = np.load(BytesIO(file_io.read_file_to_string(
            FLAGS.dataset, binary_mode=True)))['data']
    else:
        data = np.load(FLAGS.dataset)['data']

    tokenizer = None
    encoder = None
    nb_words = None
    seq_len = None
    if architecture == semisup.architectures.cnn_sentiment_model:
        if FLAGS.dictionaries.startswith('gs'):
            # Read as bytes so pickle receives a binary stream.
            with BytesIO(file_io.read_file_to_string(
                    FLAGS.dictionaries, binary_mode=True)) as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)
        else:
            with open(FLAGS.dictionaries, 'rb') as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)

        num_labels = len(encoder.classes_)
        image_shape = [seq_len]

    else:
        raise ValueError(
            'Only the cnn_sentiment model is supported at the moment!')

    graph = tf.Graph()
    with graph.as_default():

        # Reshape if necessary.
        if FLAGS.new_size > 0:
            new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
        else:
            new_shape = None

        # Create function that defines the network.
        if architecture == semisup.architectures.cnn_sentiment_model:

            dim = None
            if FLAGS.w2v is not None:
                path = FLAGS.w2v.split(',')[0]
                tf.logging.info('Loading word2vec: {}'.format(path))
                with open(path, 'r') as fin:
                    dim = int(fin.readline().split()[1])

            if dim is None:
                dim = FLAGS.word_embedding_dim

            model_function = partial(
                architecture,
                nb_words=nb_words,
                embedding_dim=dim,
                static_embedding=1,
                embedding_weights=None,
                ############################################################
                img_shape=image_shape,
                emb_size=FLAGS.emb_size)
        else:
            model_function = partial(architecture,
                                     is_training=False,
                                     new_shape=new_shape,
                                     img_shape=image_shape,
                                     augmentation_function=None,
                                     image_summary=False,
                                     emb_size=FLAGS.emb_size)

        # Set up semisup model.
        model = semisup.SemisupModel(model_function, num_labels, image_shape)

        with tf.Session() as sess:
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model_path))

            idx = 0
            batch_size = FLAGS.eval_batch_size
            num_elements = data.shape[0]
            res = np.zeros((num_elements, FLAGS.emb_size))

            while idx < num_elements:
                r_idx = min(idx + batch_size, num_elements)

                tf.logging.info('Running: {}-{}/{}'.format(
                    idx, r_idx, num_elements))

                batch = data[idx:r_idx]
                res[idx:r_idx] = model.calc_embedding(batch, model.test_emb)

                idx += batch_size

            tf.logging.info('Saving results to: {}'.format(FLAGS.output))
            np.save(FLAGS.output, res)
Example No. 12
    def mockModel(self):
        def dummy(*args, **kwargs):
            return np.zeros((100, 128))

        return semisup.SemisupModel(dummy, 10, [1])
def main(_):
  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  import numpy as np
  if 0:

    indices = [374, 2507, 9755, 12953, 16507, 16873, 23474,
            23909, 30280, 35070, 49603, 50106, 51171, 51726, 51805, 55205, 57251, 57296, 57779, 59154] + \
               [16644, 45576, 52886, 42140, 29201, 7767, 24, 134, 8464, 15022, 15715, 15602, 11030, 3898, 10195, 1454,
                3290, 5293, 5806, 274]
    indices = [374, 2507, 9755, 12953, 16507, 16873, 23474,
            23909, 30280, 35070, 49603, 50106, 51171, 51726, 51805, 55205, 57251, 57296, 57779, 59154] + \
               [9924, 34058, 53476, 15715, 6428, 33598, 33464, 41753, 21250, 26389, 12950,
                12464, 3795, 6761, 5638, 3952, 8300, 5632, 1475, 1875]
    sup_by_label = [ [] for i in range(NUM_LABELS)]
    for ind in indices:
      i = train_labels[ind]
      sup_by_label[i] = sup_by_label[i] + [train_images[ind]]

    for i in range(10):
      sup_by_label[i] = np.asarray(sup_by_label[i])

    sup_by_label = np.asarray(sup_by_label)


  add_random_samples = 20
  if add_random_samples > 0:
    rng = np.random.RandomState()
    indices = rng.choice(len(train_images), add_random_samples, False)

    for i in indices:
      l = train_labels[i]
      sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])


  graph = tf.Graph()
  with graph.as_default():
    model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                 IMAGE_SHAPE)

    # Set up inputs.
    t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                             FLAGS.unsup_batch_size)
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels, visit_weight=FLAGS.visit_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        model.step,
        FLAGS.decay_steps,
        FLAGS.decay_factor,
        staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      _, summaries = sess.run([train_op, summary_op])
      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
Example No. 14
def main(_):
    # Get dataset-related toolbox.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    architecture = getattr(semisup.architectures, FLAGS.architecture)

    nb_words = None
    if architecture == semisup.architectures.cnn_sentiment_model:
        with open(FLAGS.dictionaries, 'rb') as fin:
            tokenizer, encoder, nb_words, seq_len = pickle.load(fin)

        dataset_tools.IMAGE_SHAPE = [seq_len]
        dataset_tools.NUM_LABELS = len(encoder.classes_)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    test_images, test_labels = dataset_tools.get_data('test')

    graph = tf.Graph()
    with graph.as_default():

        # Set up input pipeline.
        image, label = tf.train.slice_input_producer(
            [test_images, test_labels])
        images, labels = tf.train.batch([image, label],
                                        batch_size=FLAGS.eval_batch_size)
        images = tf.cast(images, tf.float32)
        labels = tf.cast(labels, tf.int64)

        # Reshape if necessary.
        if FLAGS.new_size > 0:
            new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
        else:
            new_shape = None

        # Create function that defines the network.
        if architecture == semisup.architectures.cnn_sentiment_model:

            dim = None
            if FLAGS.w2v is not None:
                path = FLAGS.w2v.split(',')[0]
                tf.logging.info('Loading word2vec: {}'.format(path))
                w2v = KeyedVectors.load_word2vec_format(
                    path, unicode_errors='ignore')
                dim = w2v.vector_size

            if dim is None:
                dim = FLAGS.word_embedding_dim

            model_function = partial(
                architecture,
                nb_words=nb_words,
                embedding_dim=dim,
                static_embedding=1,
                embedding_weights=None,
                ############################################################
                img_shape=image_shape,
                emb_size=FLAGS.emb_size)
        else:
            model_function = partial(architecture,
                                     is_training=False,
                                     new_shape=new_shape,
                                     img_shape=image_shape,
                                     augmentation_function=None,
                                     image_summary=False,
                                     emb_size=FLAGS.emb_size)

        # Set up semisup model.
        model = semisup.SemisupModel(model_function,
                                     num_labels,
                                     image_shape,
                                     test_in=images)

        # Add moving average variables.
        for var in tf.get_collection('moving_vars'):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        for var in slim.get_model_variables():
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

        # Get prediction tensor from semisup model.
        predictions = tf.argmax(model.test_logit, 1)

        # Accuracy metric for summaries.
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
        })
        for name, value in names_to_values.items():
            tf.summary.scalar(name, value)

        # Run the actual evaluation loop.
        num_batches = math.ceil(
            len(test_labels) / float(FLAGS.eval_batch_size))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.logdir + '/train',
            logdir=FLAGS.logdir + '/eval',
            num_evals=num_batches,
            eval_op=[v for v in names_to_updates.values()],
            eval_interval_secs=FLAGS.eval_interval_secs,
            session_config=config,
            timeout=FLAGS.timeout)
Example No. 15
def main(argv):
    del argv

    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, _ = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
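                    # With probability 0.5 per image, invert pixel values
                    # (x -> 255 - x); otherwise return the image unchanged.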
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # need to add additional labels for virtual embeddings
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            visit_weight_envelope_steps = (
                FLAGS.walker_weight_envelope_steps
                if FLAGS.visit_weight_envelope_steps == -1
                else FLAGS.visit_weight_envelope_steps)
            visit_weight_envelope_delay = (
                FLAGS.walker_weight_envelope_delay
                if FLAGS.visit_weight_envelope_delay == -1
                else FLAGS.visit_weight_envelope_delay)
            visit_weight = apply_envelope(
                type=FLAGS.visit_weight_envelope,
                step=model.step,
                final_weight=FLAGS.visit_weight,
                growing_steps=visit_weight_envelope_steps,
                delay=visit_weight_envelope_delay)
            walker_weight = apply_envelope(
                type=FLAGS.walker_weight_envelope,
                step=model.step,
                final_weight=FLAGS.walker_weight,
                growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                delay=FLAGS.walker_weight_envelope_delay)
            tf.summary.scalar('Weights_Visit', visit_weight)
            tf.summary.scalar('Weights_Walker', walker_weight)

            if FLAGS.unsup_samples != 0:
                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(FLAGS.learning_rate,
                                           model.step,
                                           FLAGS.decay_steps,
                                           FLAGS.decay_factor,
                                           staircase=True),
                FLAGS.minimum_learning_rate)
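            # With staircase=True this is equivalent to
            #   lr(step) = max(learning_rate * decay_factor ** floor(step / decay_steps),
            #              minimum_learning_rate)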

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)  # pylint:disable=line-too-long

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
Example No. 16
def main(argv):
    del argv

    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                      merge_devices=True)):

            # Set up inputs.
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(
                    t_sup_images, [0, 0, 0, 0],
                    [FLAGS.sup_per_batch * (
                        num_labels - FLAGS.remove_classes)] +
                    image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels], minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            variables_to_train = [
                v for v in tf.trainable_variables()
                if v.name.startswith('net/fully_connected')
            ]
            for v in variables_to_train:
                print(v.name, v.shape)

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(
                    FLAGS.learning_rate,
                    model.step,
                    FLAGS.decay_steps,
                    FLAGS.decay_factor,
                    staircase=True),
                FLAGS.minimum_learning_rate)
            total_loss = tf.losses.get_total_loss()
            optimizer = tf.train.AdamOptimizer(t_learning_rate)

            variables_to_restore = None
            restore_ckpt = None
            if FLAGS.pretrained:
              variables_to_restore = [v for v in tf.trainable_variables()
                                      if not (v.name.startswith('net/fully_connected'))]
              restore_ckpt = FLAGS.pretrained
              for v in variables_to_restore:
                print(v.name)
            learning.train(graph, FLAGS.logdir,
                           total_loss, optimizer,
                           variables_to_train, model.step,
                           num_steps=FLAGS.max_steps, log_interval=20,
                           summary_interval=100, snapshot_interval=5000,
                           variables_to_restore=variables_to_restore, restore_ckpt=restore_ckpt)
            return

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(max_to_keep=FLAGS.max_checkpoints,
                                   keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)  # pylint:disable=line-too-long

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
def main(_):
    if FLAGS.logdir is not None:
        if FLAGS.taskid is not None:
            FLAGS.logdir = FLAGS.logdir + '/t_' + str(FLAGS.taskid)
        else:
            FLAGS.logdir = FLAGS.logdir + '/t_' + str(random.randint(0, 99999))

    dataset_tools = import_module('tools.' + FLAGS.dataset)

    NUM_LABELS = dataset_tools.NUM_LABELS
    num_labels = NUM_LABELS
    IMAGE_SHAPE = dataset_tools.IMAGE_SHAPE
    image_shape = IMAGE_SHAPE

    train_images, train_labels_svm = dataset_tools.get_data(
        'train')  # training labels are only used for SVM evaluation, not for training
    test_images, test_labels = dataset_tools.get_data('test')

    if FLAGS.zero_fact < 1:
        # exclude a random set of zeros (not at the end, then there would be many batches without zeros)
        keep = np.ones(len(train_labels_svm), dtype=bool)
        zero_indices = np.where((train_labels_svm == 0))[0]

        remove = np.random.uniform(0, 1, len(zero_indices))
        zero_indices_to_remove = zero_indices[remove > FLAGS.zero_fact]
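        # e.g. zero_fact=0.1 keeps roughly 10% of the zero-labeled samples.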

        keep[zero_indices_to_remove] = False

        train_images = train_images[keep]
        train_labels_svm = train_labels_svm[keep]

        print(
            'using only a fraction of zeros, resulting in the following shape:',
            train_images.shape)

    if FLAGS.num_unlabeled_images > 0:
        unlabeled_train_images, _ = dataset_tools.get_data(
            'unlabeled', max_num=np.min([FLAGS.num_unlabeled_images, 50000]))
        train_images = np.vstack([train_images, unlabeled_train_images])

    if FLAGS.normalize_input:
        train_images = (train_images - 128.) / 128.
        test_images = (test_images - 128.) / 128.

    if FLAGS.use_test:
        train_images = np.vstack([train_images, test_images])
        train_labels_svm = np.hstack([train_labels_svm, test_labels])

    #if FLAGS.dataset == 'svhn' and FLAGS.architecture == 'resnet_cifar_model':
    #  FLAGS.emb_size = 64

    image_shape_crop = image_shape
    c_test_imgs = test_images
    c_train_imgs = train_images

    # Crop images to a random region. Intuitively, images should belong to the
    # same cluster even if part of the image is missing. (No padding, because
    # the net could detect padding easily and match it to other augmented
    # samples that also have padding.)
    if FLAGS.dataset == 'stl10':
        image_shape_crop = [64, 64, 3]
        c_test_imgs = test_images[:, 16:80, 16:80]
        c_train_imgs = train_images[:, 16:80, 16:80]

    def aug(image):
        return apply_augmentation(image,
                                  target_shape=image_shape_crop,
                                  params=dataset_tools.augmentation_params)

    def random_crop(image):
        image_size = image_shape_crop[0]
        image = tf.random_crop(image, [image_size, image_size, image_shape[2]])

        return image

    graph = tf.Graph()
    with graph.as_default():
        t_images = tf.placeholder("float", shape=[None] + image_shape)

        dataset = Dataset.from_tensor_slices(t_images)
        dataset = dataset.shuffle(
            buffer_size=10000,
            seed=47)  # important, so that we have the same images in both sets

        # parameters for buffering during augmentation. Only influence training speed.
        nt = 8 if FLAGS.volta else 4  # that's not even enough, but there are no more CPUs
        b = 10000

        rf = FLAGS.num_augmented_samples

        augmented_set = dataset
        if FLAGS.shuffle_augmented_samples:
            augmented_set = augmented_set.shuffle(buffer_size=10000, seed=47)

        # get multiple augmented versions of the same image - they should later have similar embeddings
        augmented_set = augmented_set.flat_map(
            lambda x: Dataset.from_tensors(x).repeat(rf))
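        # e.g. with rf = 4 an input stream [x1, x2, ...] becomes
        # [x1, x1, x1, x1, x2, x2, x2, x2, ...] before augmentation is applied.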

        augmented_set = augmented_set.map(aug)

        dataset = dataset.map(
            random_crop)  # why apply random crop to the batch B
        dataset = dataset.repeat().batch(FLAGS.unsup_batch_size)
        augmented_set = augmented_set.repeat().batch(FLAGS.unsup_batch_size *
                                                     rf)

        iterator = dataset.make_initializable_iterator()
        reg_iterator = augmented_set.make_initializable_iterator()

        t_unsup_images = iterator.get_next()  # unaugmented image batch A
        t_reg_unsup_images = reg_iterator.get_next()  # augmented image batch B

        model_func = getattr(semisup.architectures, FLAGS.architecture)

        model = semisup.SemisupModel(
            model_func,
            num_labels,
            image_shape_crop,
            optimizer='adam',
            emb_size=FLAGS.emb_size,
            dropout_keep_prob=FLAGS.dropout_keep_prob,
            num_blocks=FLAGS.num_blocks,
            normalize_embeddings=FLAGS.normalize_embeddings,
            beta1=FLAGS.beta1,
            beta2=FLAGS.beta2)

        # initialization of centroid variables
        init_virt = []
        for c in range(num_labels):
            # one cluster of virtual_embeddings_per_class centroids per label
            center = np.random.normal(0, 0.3, size=[1, FLAGS.emb_size])
            noise = np.random.uniform(
                -0.01,
                0.01,
                size=[FLAGS.virtual_embeddings_per_class, FLAGS.emb_size])
            centroids = noise + center
            init_virt.extend(centroids)
        # centroid variables (stored as tf.variable)
        t_sup_emb = tf.Variable(tf.cast(np.array(init_virt), tf.float32),
                                name="virtual_centroids")

        t_sup_labels = tf.constant(
            np.concatenate([[i] * FLAGS.virtual_embeddings_per_class
                            for i in range(num_labels)]))
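        # e.g. with 3 classes and virtual_embeddings_per_class=2 this yields
        # labels [0, 0, 1, 1, 2, 2], matching the centroid rows above.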

        visit_weight = tf.placeholder("float", shape=[])
        walker_weight = tf.placeholder("float", shape=[])
        t_logit_weight = tf.placeholder("float", shape=[])
        t_trafo_weight = tf.placeholder("float", shape=[])

        t_l1_weight = tf.placeholder("float", shape=[])
        t_norm_weight = tf.placeholder("float", shape=[])
        t_learning_rate = tf.placeholder("float", shape=[])
        t_sat_loss_weight = tf.placeholder("float", shape=[])

        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        t_reg_unsup_emb = model.image_to_embedding(t_reg_unsup_images)

        t_all_unsup_emb = tf.concat([t_unsup_emb, t_reg_unsup_emb], axis=0)
        t_rsup_labels = tf.constant(
            np.concatenate([[i] * rf for i in range(FLAGS.unsup_batch_size)]))
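        # Labels [0]*rf + [1]*rf + ... identify which image of batch A each
        # augmented embedding in batch B came from (assuming the augmented set
        # was not reshuffled), mirroring the flat_map ordering above.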

        rwalker_weight = tf.placeholder("float", shape=[])
        rvisit_weight = tf.placeholder("float", shape=[])

        if FLAGS.normalize_embeddings:
            t_sup_logit = model.embedding_to_logit(
                tf.nn.l2_normalize(t_sup_emb, dim=1))
            model.add_semisup_loss(tf.nn.l2_normalize(t_sup_emb, dim=1),
                                   tf.nn.l2_normalize(t_unsup_emb, dim=1),
                                   t_sup_labels,
                                   walker_weight=walker_weight,
                                   visit_weight=visit_weight,
                                   match_scale=FLAGS.scale_match_ab)
            model.reg_loss_aba = model.add_semisup_loss(
                tf.nn.l2_normalize(t_reg_unsup_emb, dim=1),
                tf.nn.l2_normalize(t_unsup_emb, dim=1),
                t_rsup_labels,
                walker_weight=rwalker_weight,
                visit_weight=rvisit_weight,
                match_scale=FLAGS.scale_match_ab,
                est_err=False)

        else:
            t_sup_logit = model.embedding_to_logit(t_sup_emb)
            # loss assoc,c
            model.add_semisup_loss(t_sup_emb,
                                   t_unsup_emb,
                                   t_sup_labels,
                                   walker_weight=walker_weight,
                                   visit_weight=visit_weight,
                                   match_scale=FLAGS.scale_match_ab,
                                   est_err=True,
                                   name='c_association')
            # loss assoc,aug
            model.reg_loss_aba = model.add_semisup_loss(
                t_reg_unsup_emb,
                t_unsup_emb,
                t_rsup_labels,
                walker_weight=rwalker_weight,
                visit_weight=rvisit_weight,
                match_scale=FLAGS.scale_match_ab,
                est_err=False,
                name='aug_association')

        model.add_logit_loss(t_sup_logit, t_sup_labels, weight=t_logit_weight)

        # Keep only the first augmented copy of each image (stride rf).
        t_reg_unsup_emb_singled = t_reg_unsup_emb[::FLAGS.num_augmented_samples]

        t_unsup_logit = model.embedding_to_logit(t_unsup_emb)
        t_reg_unsup_logit = model.embedding_to_logit(t_reg_unsup_emb_singled)

        model.add_sat_loss(t_unsup_logit,
                           t_reg_unsup_logit,
                           weight=t_sat_loss_weight)

        trafo_lc = semisup.NO_FC_COLLECTION if FLAGS.trafo_separate_loss_collection else semisup.LOSSES_COLLECTION

        if FLAGS.trafo_weight > 0:
            # only use a single augmented sample per sample

            t_trafo_loss = model.add_transformation_loss(
                t_unsup_emb,
                t_reg_unsup_emb_singled,
                t_unsup_logit,
                t_reg_unsup_logit,
                FLAGS.unsup_batch_size,
                weight=t_trafo_weight,
                label_smoothing=0,
                loss_collection=trafo_lc)
        else:
            t_trafo_loss = tf.constant(0)

        model.add_emb_regularization(t_all_unsup_emb, weight=t_l1_weight)
        model.add_emb_regularization(t_sup_emb, weight=t_l1_weight)

        # make l2 norm = 3
        model.add_emb_normalization(t_sup_emb,
                                    weight=t_norm_weight,
                                    target=FLAGS.norm_target)
        model.add_emb_normalization(t_all_unsup_emb,
                                    weight=t_norm_weight,
                                    target=FLAGS.norm_target)

        gradient_multipliers = {t_sup_emb: 1}
        [train_op, train_op_sat
         ] = model.create_train_op(t_learning_rate,
                                   gradient_multipliers=gradient_multipliers)

        summary_op = tf.summary.merge_all()
        if FLAGS.logdir is not None:
            summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)
            saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        sess.run(iterator.initializer, feed_dict={t_images: train_images})
        sess.run(reg_iterator.initializer,
                 feed_dict={t_images: train_images})

        # optional: init from autoencoder
        if FLAGS.restore_checkpoint is not None:
            # logit fc layer cannot be restored
            def is_main_net(x):
                return 'logit_fc' not in x.name and 'Adam' not in x.name

            variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope='net')
            variables = list(filter(is_main_net, variables))

            restorer = tf.train.Saver(var_list=variables)
            restorer.restore(sess, FLAGS.restore_checkpoint)

        extra_feed_dict = {}

        from numpy.linalg import norm

        reg_warmup_steps = FLAGS.reg_warmup_steps
        logit_weight_ = FLAGS.logit_weight
        rwalker_weight_ = FLAGS.rwalker_weight
        rvisit_weight_ = FLAGS.rvisit_weight
        learning_rate_ = FLAGS.learning_rate
        trafo_weight = FLAGS.trafo_weight

        kmeans_initialized = False

        for step in range(FLAGS.max_steps):
            import time
            start = time.time()
            if FLAGS.init_with_kmeans:
                if ((FLAGS.kmeans_sat_thresh is not None and not kmeans_initialized) or
                        (FLAGS.kmeans_sat_thresh is None and step <= reg_warmup_steps)):
                    walker_weight_ = 0
                    visit_weight_ = 0
                    logit_weight_ = 0
                    trafo_weight = 0
                else:
                    walker_weight_ = FLAGS.walker_weight
                    visit_weight_ = FLAGS.visit_weight_base
                    logit_weight_ = FLAGS.logit_weight
                    trafo_weight = FLAGS.trafo_weight
            else:
                walker_weight_ = apply_envelope("log", step,
                                                FLAGS.walker_weight,
                                                reg_warmup_steps, 0)
                visit_weight_ = apply_envelope("log", step,
                                               FLAGS.visit_weight_base,
                                               reg_warmup_steps, 0)

            feed_dict = {
                rwalker_weight:
                rwalker_weight_ * FLAGS.reg_association_weight,
                rvisit_weight:
                rvisit_weight_ * FLAGS.reg_association_weight,
                walker_weight:
                walker_weight_ * FLAGS.cluster_association_weight,
                visit_weight:
                visit_weight_ * FLAGS.cluster_association_weight,
                t_l1_weight:
                FLAGS.l1_weight,
                t_norm_weight:
                FLAGS.norm_weight,
                t_logit_weight:
                logit_weight_,
                t_trafo_weight:
                trafo_weight,
                t_sat_loss_weight:
                0,
                t_learning_rate:
                1e-6 + apply_envelope("log", step, learning_rate_,
                                      FLAGS.warmup_steps, 0)
            }
            _, sat_loss, train_loss, summaries, centroids, unsup_emb, reg_unsup_emb, estimated_error, p_ab, p_ba, p_aba, \
            reg_loss, trafo_loss = sess.run(
                    [train_op, train_op_sat, model.train_loss, summary_op, t_sup_emb, t_unsup_emb, t_reg_unsup_emb,
                     model.estimate_error, model.p_ab,
                     model.p_ba, model.p_aba, model.reg_loss_aba, t_trafo_loss], {**extra_feed_dict, **feed_dict})

            if FLAGS.kmeans_sat_thresh is not None and step % 200 == 0 and not kmeans_initialized:
                sat_score = semisup.calc_sat_score(unsup_emb, reg_unsup_emb)

                if sat_score > FLAGS.kmeans_sat_thresh:
                    print('initializing with kmeans', step, sat_score)
                    FLAGS.init_with_kmeans = True
                    kmeans_initialized = True
                    reg_warmup_steps = step  # -> jump to next if clause

            if FLAGS.init_with_kmeans and step == reg_warmup_steps:
                # do kmeans, initialize with kmeans
                embs = model.calc_embedding(c_train_imgs, model.test_emb, sess,
                                            extra_feed_dict)

                kmeans = semisup.KMeans(n_clusters=num_labels,
                                        random_state=0).fit(embs)

                init_virt = []
                noise = 0.0001
                for c in range(num_labels):
                    center = kmeans.cluster_centers_[c]
                    noise = np.random.uniform(
                        -noise,
                        noise,
                        size=[
                            FLAGS.virtual_embeddings_per_class, FLAGS.emb_size
                        ])
                    centroids = noise + center
                    init_virt.extend(centroids)

                # init with K-Means
                assign_op = t_sup_emb.assign(np.array(init_virt))
                sess.run(assign_op)
                model.reset_optimizer(sess)

                rwalker_weight_ *= FLAGS.reg_decay_factor
                rvisit_weight_ *= FLAGS.reg_decay_factor

            if FLAGS.svm_test_interval is not None and step % FLAGS.svm_test_interval == 0 and step > 0:
                svm_test_score, _ = model.train_and_eval_svm(c_train_imgs,
                                                             train_labels_svm,
                                                             c_test_imgs,
                                                             test_labels,
                                                             sess,
                                                             num_samples=5000)
                print('svm score:', svm_test_score)
                test_pred = model.classify(c_test_imgs, sess)
                train_pred = model.classify(c_train_imgs, sess)
                svm_test_score, _ = model.train_and_eval_svm_on_preds(
                    train_pred,
                    train_labels_svm,
                    test_pred,
                    test_labels,
                    sess,
                    num_samples=5000)
                print('svm score on logits:', svm_test_score)

            if step % FLAGS.decay_steps == 0 and step > 0:
                learning_rate_ = learning_rate_ * FLAGS.decay_factor

            if step == 0 or (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                print('trafo loss', trafo_loss)
                print('reg loss', reg_loss)
                print('Time for step', time.time() - start)
                test_pred = model.classify(c_test_imgs, sess,
                                           extra_feed_dict).argmax(-1)

                nmi = semisup.calc_nmi(test_pred, test_labels)

                conf_mtx, score = semisup.calc_correct_logit_score(
                    test_pred, test_labels, num_labels)
                print('Confusion matrix')
                print(conf_mtx)
                print('Test error: %.2f %%' % (100 - score * 100))
                print('Test NMI: %.2f %%' % (nmi * 100))
                print('Train loss: %.2f ' % train_loss)
                print('Train loss no fc: %.2f ' % sat_loss)
                print('Estimated Accuracy: %.2f ' % estimated_error)

                sat_score = semisup.calc_sat_score(unsup_emb, reg_unsup_emb)
                print('sat accuracy', sat_score)

                embs = model.calc_embedding(c_test_imgs, model.test_emb, sess,
                                            extra_feed_dict)

                c_n = norm(centroids, axis=1, ord=2)
                e_n = norm(embs[0:100], axis=1, ord=2)
                print('centroid norm', np.mean(c_n))
                print('embedding norm', np.mean(e_n))

                k_conf_mtx, k_score = semisup.do_kmeans(
                    embs, test_labels, num_labels)
                print('k means Confusion matrix')

                print(k_conf_mtx)
                print('k means score:',
                      k_score)  # sometimes k-means scores better than the logits

                if FLAGS.logdir is not None:
                    sum_values = {
                        'test score': score,
                        'reg loss': reg_loss,
                        'centroid norm': np.mean(c_n),
                        'embedding norm': np.mean(e_n),
                        'k means score': k_score
                    }

                    summary_writer.add_summary(summaries, step)

                    for key, value in sum_values.items():
                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag=key, simple_value=value)
                        ])
                        summary_writer.add_summary(summary, step)

                # early stopping to save some time
                if step == 34999 and score < 0.45:
                    break
                if step == 14999 and score < 0.225:
                    break

                if dataset == 'mnist' and step == 6999 and score < 0.25:
                    break

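        # Final SVM evaluation on a larger sample once training has finished.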
        svm_test_score, _ = model.train_and_eval_svm(c_train_imgs,
                                                     train_labels_svm,
                                                     c_test_imgs,
                                                     test_labels,
                                                     sess,
                                                     num_samples=10000)

        if FLAGS.logdir is not None:
            path = saver.save(sess, FLAGS.logdir, model.step)
            print('@@model_path:%s' % path)

        print('FINAL RESULTS:')
        print(conf_mtx)
        print('Test error: %.2f %%' % (100 - score * 100))
        print('final_score', score)

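        # The '@@key:value' lines below look like machine-parsable markers for an
        # external experiment harness (assumption; format kept as-is).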
        print('@@test_error:%.4f' % score)
        print('@@train_loss:%.4f' % train_loss)
        print('@@reg_loss:%.4f' % reg_loss)
        print('@@estimated_error:%.4f' % estimated_error)
        print('@@centroid_norm:%.4f' % np.mean(c_n))
        print('@@emb_norm:%.4f' % np.mean(e_n))
        print('@@k_score:%.4f' % k_score)
        print('@@svm_score:%.4f' % svm_test_score)


def main(_):
    from inception.imagenet_data import ImagenetData
    from inception import image_processing
    dataset = ImagenetData(subset='train')
    assert dataset.data_files()
    NUM_LABELS = dataset.num_classes() + 1
    IMAGE_SHAPE = [FLAGS.image_size, FLAGS.image_size, 3]
    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(inception_model, NUM_LABELS, IMAGE_SHAPE)

        # t_sup_images, t_sup_labels = tools.get_data('train')
        # t_unsup_images, _ = tools.get_data('unlabeled')

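        # batch_inputs yields preprocessed ImageNet image/label batches; the same
        # stream is re-batched below into a "supervised" and an "unsupervised" queue.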
        images, labels = image_processing.batch_inputs(
            dataset,
            32,
            train=True,
            num_preprocess_threads=FLAGS.num_readers,
            num_readers=FLAGS.num_readers)

        t_sup_images, t_sup_labels = tf.train.batch(
            [images, labels],
            batch_size=FLAGS.sup_batch_size,
            enqueue_many=True,
            num_threads=FLAGS.num_readers,
            capacity=1000 + 3 * FLAGS.sup_batch_size,
        )

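        # Note: the "unlabeled" batch is drawn from the same labeled stream (its
        # labels are simply ignored) and reuses sup_batch_size rather than a
        # dedicated unsupervised batch size.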
        t_unsup_images, t_unsup_labels = tf.train.batch(
            [images, labels],
            batch_size=FLAGS.sup_batch_size,
            enqueue_many=True,
            num_threads=FLAGS.num_readers,
            capacity=1000 + 3 * FLAGS.sup_batch_size,
        )

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        model.add_semisup_loss(t_sup_emb,
                               t_unsup_emb,
                               t_sup_labels,
                               visit_weight=FLAGS.visit_weight)

        model.add_logit_loss(t_sup_logit, t_sup_labels)

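        # Exponential staircase decay with a lower bound enforced via tf.maximum.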
        t_learning_rate = tf.maximum(
            tf.train.exponential_decay(FLAGS.learning_rate,
                                       model.step,
                                       FLAGS.decay_steps,
                                       FLAGS.decay_factor,
                                       staircase=True),
            FLAGS.minimum_learning_rate)

        # Create training operation and start the actual training loop.
        train_op = model.create_train_op(t_learning_rate)

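        # Allocate GPU memory on demand instead of reserving it all upfront.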
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        slim.learning.train(train_op,
                            logdir=FLAGS.logdir,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            startup_delay_steps=(FLAGS.task * 20),
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            session_config=config)
Example No. 19
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(
        0, 1000)
    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)

    # Optionally add a few extra labeled samples for every class.
    if FLAGS.testadd:
        num_to_add = 10
        for i in range(10):
            items = np.where(train_labels == i)[0]
            inds = np.random.choice(len(items), num_to_add, replace=False)
            sup_by_label[i] = np.vstack(
                [sup_by_label[i], train_images[items[inds]]])

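    # Optionally mix in a few uniformly sampled labeled examples; disabled here
    # (hard-coded to 0).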
    add_random_samples = 0
    if add_random_samples > 0:
        rng = np.random.RandomState()
        indices = rng.choice(len(train_images), add_random_samples, False)

        for i in indices:
            l = train_labels[i]
            sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])
            print(l)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     dropout_keep_prob=0.8)

        # Set up inputs.
        if FLAGS.random_batches:
            sup_lbls = np.asarray(
                np.hstack([
                    np.ones(len(i)) * ind for ind, i in enumerate(sup_by_label)
                ]), np.int)
            sup_images = np.vstack(sup_by_label)
            t_sup_images, t_sup_labels = semisup.create_input(
                sup_images, sup_lbls, FLAGS.sup_per_batch * NUM_LABELS)
        else:
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
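            # Unlabeled batches are either drawn class-balanced or sampled
            # uniformly from the full training set.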
            if FLAGS.equal_cls_unsup:
                allimgs_bylabel = semisup.sample_by_label(
                    train_images, train_labels, 5000, NUM_LABELS, seed)
                t_unsup_images, _ = semisup.create_per_class_inputs(
                    allimgs_bylabel, FLAGS.sup_per_batch)
            else:
                t_unsup_images, _ = semisup.create_input(
                    train_images, train_labels, FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(t_sup_emb,
                                   t_unsup_emb,
                                   t_sup_labels,
                                   walker_weight=FLAGS.walker_weight,
                                   visit_weight=FLAGS.visit_weight,
                                   proximity_weight=FLAGS.proximity_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

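        # Start the TF1 queue runners that feed the input pipelines built above.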
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
def main(_):
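    # Evaluation logs go into a dedicated subdirectory that is wiped before each run.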
    if FLAGS.logdir is not None:
        FLAGS.logdir = FLAGS.logdir + '/t_mnist_eval'
        try:
            shutil.rmtree(FLAGS.logdir)
        except OSError as e:
            print ("Error: %s - %s." % (e.filename, e.strerror))

    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    if FLAGS.sup_seed >= 0:
        seed = FLAGS.sup_seed
    elif FLAGS.sup_seed == -2:
        seed = FLAGS.sup_per_class
    else:
        seed = np.random.randint(0, 1000)

    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS, seed)


    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                     IMAGE_SHAPE, dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up inputs.
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                    sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                         FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(
                    t_sup_emb, t_unsup_emb, t_sup_labels,
                    walker_weight=FLAGS.walker_weight, visit_weight=FLAGS.visit_weight)

            #model.add_emb_regularization(t_unsup_emb, weight=FLAGS.l1_weight)

        model.add_logit_loss(t_sup_logit, t_sup_labels, weight=FLAGS.logit_weight)

        #model.add_emb_regularization(t_sup_emb, weight=FLAGS.l1_weight)

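        # The learning rate is a placeholder so warm-up and decay can be computed
        # in Python and fed in per step.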
        t_learning_rate = tf.placeholder("float", shape=[])

        train_op = model.create_train_op(t_learning_rate)

        summary_op = tf.summary.merge_all()
        if FLAGS.logdir is not None:
            summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)
            saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        learning_rate_ = FLAGS.learning_rate

        for step in range(FLAGS.max_steps):
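            # Logarithmic warm-up from ~1e-6 up to the base learning rate over
            # the first FLAGS.warmup_steps steps.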
            lr = learning_rate_
            if step < FLAGS.warmup_steps:
                lr = 1e-6 + semisup.apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)

            _, summaries = sess.run([train_op, summary_op], {
              t_learning_rate: lr
            })

            sys.stderr.write("\rstep: %d" % step)
            sys.stderr.flush()

            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('\nStep: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)

    
                if FLAGS.logdir is not None:
                    sum_values = {
                        'Test error': test_err
                    }
                    summary_writer.add_summary(summaries, step)
                    for key, value in sum_values.items():
                        summary = tf.Summary(
                            value=[tf.Summary.Value(tag=key, simple_value=value)])
                        summary_writer.add_summary(summary, step)

            if step % FLAGS.decay_steps == 0 and step > 0:
                learning_rate_ = learning_rate_ * FLAGS.decay_factor

        coord.request_stop()
        coord.join(threads)

    print('FINAL RESULTS:')
    print('Test error: %.2f %%' % (test_err))
    print('final_score', 1 - test_err/100)

    print('@@test_error:%.4f' % (test_err/100))
    print('@@train_loss:%.4f' % 0)
    print('@@reg_loss:%.4f' % 0)
    print('@@estimated_error:%.4f' % 0)
    print('@@centroid_norm:%.4f' % 0)
    print('@@emb_norm:%.4f' % 0)
    print('@@k_score:%.4f' % 0)
    print('@@svm_score:%.4f' % 0)
Example No. 21
def main(_):
    # Get dataset-related toolbox.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    test_images, test_labels = dataset_tools.get_data('test')

    graph = tf.Graph()
    with graph.as_default():

        # Set up input pipeline.
        image, label = tf.train.slice_input_producer(
            [test_images, test_labels])
        images, labels = tf.train.batch([image, label],
                                        batch_size=FLAGS.eval_batch_size)
        images = tf.cast(images, tf.float32)
        labels = tf.cast(labels, tf.int64)

        # Reshape if necessary.
        if FLAGS.new_size > 0:
            new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
        else:
            new_shape = None

        # Create function that defines the network.
        model_function = partial(architecture,
                                 is_training=False,
                                 new_shape=new_shape,
                                 img_shape=image_shape,
                                 augmentation_function=None,
                                 image_summary=False,
                                 emb_size=FLAGS.emb_size)

        # Set up semisup model.
        model = semisup.SemisupModel(model_function,
                                     num_labels,
                                     image_shape,
                                     test_in=images)

        # Add moving average variables.
        for var in tf.get_collection('moving_vars'):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        for var in slim.get_model_variables():
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

        # Get prediction tensor from semisup model.
        predictions = tf.argmax(model.test_logit, 1)

        # Accuracy metric for summaries.
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
        })
        for name, value in names_to_values.items():
            tf.summary.scalar(name, value)

        # Run the actual evaluation loop.
        num_batches = math.ceil(
            len(test_labels) / float(FLAGS.eval_batch_size))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.logdir + '/train',
            logdir=FLAGS.logdir + '/eval',
            num_evals=num_batches,
            eval_op=names_to_updates.values(),
            eval_interval_secs=FLAGS.eval_interval_secs,
            session_config=config,
            timeout=FLAGS.timeout)
Example No. 22
def main(_):

  train_images, train_labels = data.load_training_data(FLAGS.cifar)
  test_images, test_labels = data.load_test_data(FLAGS.cifar)

  print(train_images.shape, train_labels.shape)

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():

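    # Augmentation function passed to the model below; only horizontal flips are
    # active, the remaining ops are kept as commented-out options.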
    def augment(image):
        # image_size = 28
        # image = tf.image.resize_image_with_crop_or_pad(
        #    image, image_size+4, image_size+4)
        # image = tf.random_crop(image, [image_size, image_size, 3])
        image = tf.image.random_flip_left_right(image)
        # Brightness/saturation/constrast provides small gains .2%~.5% on cifar.
        # image = tf.image.random_brightness(image, max_delta=63. / 255.)
        # image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
        return image

    model_func = getattr(semisup.architectures, FLAGS.model)
    model = semisup.SemisupModel(model_func, NUM_LABELS,
                                 IMAGE_SHAPE, emb_size=256, dropout_keep_prob=FLAGS.dropout_keep_prob,
                                 augmentation_function=augment)


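    # Labeled batches are either sampled uniformly from the pooled labeled set or
    # drawn with a fixed number of examples per class.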
    if FLAGS.random_batches:
      sup_lbls = np.asarray(
          np.hstack([np.ones(len(i)) * ind
                     for ind, i in enumerate(sup_by_label)]), np.int)
      sup_images = np.vstack(sup_by_label)
      batch_size = FLAGS.sup_per_batch * NUM_LABELS
      t_sup_images, t_sup_labels = semisup.create_input(sup_images, sup_lbls, batch_size)
    else:
      t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    if FLAGS.semisup:
      if FLAGS.equal_cls_unsup:
        allimgs_bylabel = semisup.sample_by_label(train_images, train_labels,
                                                  int(NUM_TRAIN_IMAGES / NUM_LABELS), NUM_LABELS, seed)
        t_unsup_images, _ = semisup.create_per_class_inputs(
          allimgs_bylabel, FLAGS.sup_per_batch)
      else:
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)

      t_unsup_emb = model.image_to_embedding(t_unsup_images)
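      # The loss weights are placeholders so they can be scheduled from the
      # training loop (see the feed dicts below).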
      proximity_weight = tf.placeholder("float", shape=[])
      visit_weight = tf.placeholder("float", shape=[])
      walker_weight = tf.placeholder("float", shape=[])

      model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels,
        walker_weight=walker_weight, visit_weight=visit_weight, proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
      FLAGS.learning_rate,
      model.step,
      FLAGS.decay_steps,
      FLAGS.decay_factor,
      staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()
  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      feed_dict = {}
      if FLAGS.semisup:

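        # With proximity_loss enabled the visit weight is moved onto the proximity
        # term; otherwise the usual walker/visit weighting is used.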
        if FLAGS.proximity_loss:
          feed_dict = {
            walker_weight: FLAGS.walker_weight,  # 0.1 + apply_envelope("lin", step, FLAGS.walker_weight, 500, FLAGS.decay_steps)
            visit_weight: 0,
            proximity_weight: FLAGS.visit_weight,  # 0.1 + apply_envelope("lin", step, FLAGS.visit_weight, 500, FLAGS.decay_steps)
            # t_logit_weight: FLAGS.logit_weight,
            # t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
          }
        else:
          feed_dict = {
            walker_weight: FLAGS.walker_weight,  # 0.1 + apply_envelope("lin", step, FLAGS.walker_weight, 500, FLAGS.decay_steps)
            visit_weight: FLAGS.visit_weight,  # 0.1 + apply_envelope("lin", step, FLAGS.visit_weight, 500, FLAGS.decay_steps)
            proximity_weight: 0,
            # t_logit_weight: FLAGS.logit_weight,
            # t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
          }

      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], feed_dict)


      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images, sess).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        #saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)