Example #1
    def train_warmup():
        """ if there are few samples, warmup at first"""
        (sup_images_iterator, sup_labels_iterator, input_sess, input_threads,
         input_coord) = init_iterators()

        model.reset_optimizer(sess)

        for step in range(FLAGS.max_steps):
            unsup_images, _ = sess.run(unsup_images_iterator)
            si, sl = input_sess.run([sup_images_iterator, sup_labels_iterator])

            _, summaries, train_loss = sess.run(
                [train_op, summary_op, model.train_loss], {
                    t_unsup_images: unsup_images,
                    walker_weight: FLAGS.walker_weight,
                    proximity_weight: 0,
                    visit_weight: 0.1 + apply_envelope(
                        "lin", step, 0.4, FLAGS.warmup_steps, 0),
                    t_logit_weight: 0.5,  # FLAGS.logit_weight,
                    t_l1_weight: FLAGS.l1_weight,
                    t_sup_images: si,
                    t_sup_labels: sl,
                    t_learning_rate: 5e-5 + apply_envelope(
                        "log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
                })

            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print('Train loss: %.2f ' % train_loss)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()
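
The warm-up schedule above ramps its weights with apply_envelope, which is defined elsewhere in the file. A minimal sketch of plausible "lin" and "log" envelopes, ramping a weight from 0 to final_weight over growing_steps steps after an optional delay (the exact curves in the original code may differ):

import numpy as np

def apply_envelope(envelope_type, step, final_weight, growing_steps, delay=0):
    """Hypothetical sketch: ramp a weight from 0 up to final_weight."""
    step = max(0, step - delay)                # inactive during the delay
    t = min(1.0, step / float(growing_steps))  # progress in [0, 1]
    if envelope_type == "lin":
        return final_weight * t                # linear ramp
    if envelope_type == "log":
        # fast initial growth that flattens out at final_weight
        return final_weight * np.log1p(t * (np.e - 1))
    raise ValueError("unknown envelope type: %s" % envelope_type)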
Example #2
    def train_finetune(lr=0.001, steps=FLAGS.max_steps):
        (sup_images_iterator, sup_labels_iterator, input_sess, input_threads,
         input_coord) = init_iterators()

        model.reset_optimizer(sess)
        for step in range(int(steps)):
            unsup_images, _ = sess.run(unsup_images_iterator)
            si, sl = input_sess.run([sup_images_iterator, sup_labels_iterator])
            _, summaries, train_loss = sess.run(
                [train_op, summary_op, model.train_loss], {
                    t_unsup_images: unsup_images,
                    walker_weight: 1,
                    proximity_weight: 0,
                    visit_weight: 1,
                    t_l1_weight: 0,
                    t_logit_weight: 1,
                    t_learning_rate: lr,
                    t_sup_images: si,
                    t_sup_labels: sl
                })
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print('Train loss: %.2f ' % train_loss)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

        input_coord.request_stop()
        input_coord.join(input_threads)
        input_sess.close()
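
train_warmup and train_finetune both call init_iterators, which is not shown in these snippets. A hypothetical sketch of what it could look like, assuming the supervised inputs live in their own graph and session with dedicated queue-runner threads (sup_by_label and FLAGS come from the surrounding script):

def init_iterators():
    """Hypothetical sketch: per-class supervised batches in a separate session."""
    input_graph = tf.Graph()
    with input_graph.as_default():
        sup_images_iterator, sup_labels_iterator = semisup.create_per_class_inputs(
            sup_by_label, FLAGS.sup_per_batch)
        input_sess = tf.Session(graph=input_graph)
        input_sess.run(tf.global_variables_initializer())
        input_coord = tf.train.Coordinator()
        input_threads = tf.train.start_queue_runners(sess=input_sess,
                                                     coord=input_coord)
    return (sup_images_iterator, sup_labels_iterator, input_sess,
            input_threads, input_coord)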
Example #3
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS, IMAGE_SHAPE)

        # Set up inputs.
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
            sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        model.add_semisup_loss(t_sup_emb,
                               t_unsup_emb,
                               t_sup_labels,
                               visit_weight=FLAGS.visit_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
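
semisup.sample_by_label selects the small labeled subset used above. A minimal sketch of its likely behavior, returning one image array per class so that sup_by_label[l] can later be stacked and extended:

import numpy as np

def sample_by_label(images, labels, n_per_label, num_labels, seed=None):
    """Hypothetical sketch: pick n_per_label random examples of each class."""
    rng = np.random.RandomState(seed)
    result = []
    for label in range(num_labels):
        candidates = np.where(labels == label)[0]
        chosen = rng.choice(candidates, n_per_label, replace=False)
        result.append(images[chosen])
    return result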
Example #4
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(
        0, 1000)
    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)

    # Optionally add a few extra labeled samples to every class.
    if FLAGS.testadd:
        num_to_add = 10
        for i in range(10):
            items = np.where(train_labels == i)[0]
            inds = np.random.choice(len(items), num_to_add, replace=False)
            sup_by_label[i] = np.vstack(
                [sup_by_label[i], train_images[items[inds]]])

    add_random_samples = 0
    if add_random_samples > 0:
        rng = np.random.RandomState()
        indices = rng.choice(len(train_images), add_random_samples, False)

        for i in indices:
            l = train_labels[i]
            sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])
            print(l)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     dropout_keep_prob=0.8)

        # Set up inputs.
        if FLAGS.random_batches:
            sup_lbls = np.asarray(
                np.hstack([
                    np.ones(len(i)) * ind for ind, i in enumerate(sup_by_label)
                ]), int)
            sup_images = np.vstack(sup_by_label)
            t_sup_images, t_sup_labels = semisup.create_input(
                sup_images, sup_lbls, FLAGS.sup_per_batch * NUM_LABELS)
        else:
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            if FLAGS.equal_cls_unsup:
                allimgs_bylabel = semisup.sample_by_label(
                    train_images, train_labels, 5000, NUM_LABELS, seed)
                t_unsup_images, _ = semisup.create_per_class_inputs(
                    allimgs_bylabel, FLAGS.sup_per_batch)
            else:
                t_unsup_images, _ = semisup.create_input(
                    train_images, train_labels, FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(t_sup_emb,
                                   t_unsup_emb,
                                   t_sup_labels,
                                   walker_weight=FLAGS.walker_weight,
                                   visit_weight=FLAGS.visit_weight,
                                   proximity_weight=FLAGS.proximity_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
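
The random_batches branch above pushes flattened (image, label) arrays through semisup.create_input. A sketch of the TF1 queue pipeline such a helper presumably wraps (the real implementation may shuffle and batch differently):

def create_input(input_images, input_labels, batch_size):
    """Hypothetical sketch: endless shuffled batches from in-memory arrays."""
    image, label = tf.train.slice_input_producer(
        [input_images, input_labels], shuffle=True)  # one example at a time
    return tf.train.batch([image, label], batch_size=batch_size)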
Example #5
def main(_):
  FLAGS.emb_size = 128
  optimizer = 'adam'

  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)

  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():
    model_func = semisup.architectures.mnist_model
    if FLAGS.dropout_keep_prob < 1:
      model_func = semisup.architectures.mnist_model_dropout

    model = semisup.SemisupModel(model_func, NUM_LABELS,
                                 IMAGE_SHAPE, optimizer=optimizer, emb_size=FLAGS.emb_size,
                                 dropout_keep_prob=FLAGS.dropout_keep_prob)

    # Set up inputs.
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
      sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    t_unsup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)

    proximity_weight = tf.placeholder("float", shape=[])
    visit_weight = tf.placeholder("float", shape=[])
    walker_weight = tf.placeholder("float", shape=[])
    t_logit_weight = tf.placeholder("float", shape=[])
    t_l1_weight = tf.placeholder("float", shape=[])
    t_learning_rate = tf.placeholder("float", shape=[])

    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    model.add_semisup_loss(
      t_sup_emb, t_unsup_emb, t_sup_labels,
      walker_weight=walker_weight, visit_weight=visit_weight, proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels, weight=t_logit_weight)

    model.add_emb_regularization(t_sup_emb, weight=t_l1_weight)
    model.add_emb_regularization(t_unsup_emb, weight=t_l1_weight)

    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  sess = tf.InteractiveSession(graph=graph)

  unsup_images_iterator = semisup.create_input(train_images, train_labels,
                                               FLAGS.unsup_batch_size)
  tf.global_variables_initializer().run()

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  use_new_visit_loss = False
  for step in range(FLAGS.max_steps):
    unsup_images, _ = sess.run(unsup_images_iterator)

    if use_new_visit_loss:
      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], {
        t_unsup_images: unsup_images,
        walker_weight: FLAGS.walker_weight,
        proximity_weight: 0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0)
                          - apply_envelope("lin", step, FLAGS.visit_weight, 2000, FLAGS.warmup_steps),
        t_logit_weight: FLAGS.logit_weight,
        t_l1_weight: FLAGS.l1_weight,
        visit_weight: apply_envelope("lin", step, FLAGS.visit_weight, 2000, FLAGS.warmup_steps),
        t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
      })
    else:
      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], {
        t_unsup_images: unsup_images,
        walker_weight: FLAGS.walker_weight,
        visit_weight: 0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0),
        proximity_weight: 0,
        t_logit_weight: FLAGS.logit_weight,
        t_l1_weight: FLAGS.l1_weight,
        t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
      })

    if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
      print('Step: %d' % step)
      test_pred = model.classify(test_images).argmax(-1)
      conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
      test_err = (test_labels != test_pred).mean() * 100
      print(conf_mtx)
      print('Test error: %.2f %%' % test_err)
      print('Train loss: %.2f ' % train_loss)
      print()

      test_summary = tf.Summary(
        value=[tf.Summary.Value(
          tag='Test Err', simple_value=test_err)])

      summary_writer.add_summary(summaries, step)
      summary_writer.add_summary(test_summary, step)
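
To see what the else branch feeds, these are the visit_weight values produced by 0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0), assuming the linear envelope sketched after Example #1 and warmup_steps = 1000:

for step in (0, 250, 500, 750, 1000, 2000):
    print(step, 0.3 + apply_envelope("lin", step, 0.7, 1000, 0))
# 0 -> 0.300, 250 -> 0.475, 500 -> 0.650, 750 -> 0.825, >= 1000 -> 1.000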
Example #6
def main(_):
    if FLAGS.logdir is not None:
        FLAGS.logdir = FLAGS.logdir + '/t_mnist_eval'
        try:
            shutil.rmtree(FLAGS.logdir)
        except OSError as e:
            print ("Error: %s - %s." % (e.filename, e.strerror))

    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    if FLAGS.sup_seed >= 0:
      seed = FLAGS.sup_seed
    elif FLAGS.sup_seed == -2:
      seed = FLAGS.sup_per_class
    else:
      seed = np.random.randint(0, 1000)

    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS, seed)


    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                     IMAGE_SHAPE, dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up inputs.
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                    sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                         FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(
                    t_sup_emb, t_unsup_emb, t_sup_labels,
                    walker_weight=FLAGS.walker_weight, visit_weight=FLAGS.visit_weight)

            #model.add_emb_regularization(t_unsup_emb, weight=FLAGS.l1_weight)

        model.add_logit_loss(t_sup_logit, t_sup_labels, weight=FLAGS.logit_weight)

        #model.add_emb_regularization(t_sup_emb, weight=FLAGS.l1_weight)

        t_learning_rate = tf.placeholder("float", shape=[])

        train_op = model.create_train_op(t_learning_rate)

        summary_op = tf.summary.merge_all()
        if FLAGS.logdir is not None:
            summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)
            saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        learning_rate_ = FLAGS.learning_rate

        for step in range(FLAGS.max_steps):
            lr = learning_rate_
            if step < FLAGS.warmup_steps:
                lr = 1e-6 + semisup.apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)

            _, summaries = sess.run([train_op, summary_op], {
              t_learning_rate: lr
            })

            sys.stderr.write("\rstep: %d" % step)
            sys.stdout.flush()
    
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('\nStep: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)

    
                if FLAGS.logdir is not None:
                    sum_values = {'Test error': test_err}
                    summary_writer.add_summary(summaries, step)
                    for key, value in sum_values.items():
                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag=key, simple_value=value)
                        ])
                        summary_writer.add_summary(summary, step)

            if step % FLAGS.decay_steps == 0 and step > 0:
                learning_rate_ = learning_rate_ * FLAGS.decay_factor

        coord.request_stop()
        coord.join(threads)

    print('FINAL RESULTS:')
    print('Test error: %.2f %%' % (test_err))
    print('final_score', 1 - test_err/100)

    print('@@test_error:%.4f' % (test_err/100))
    print('@@train_loss:%.4f' % 0)
    print('@@reg_loss:%.4f' % 0)
    print('@@estimated_error:%.4f' % 0)
    print('@@centroid_norm:%.4f' % 0)
    print('@@emb_norm:%.4f' % 0)
    print('@@k_score:%.4f' % 0)
    print('@@svm_score:%.4f' % 0)
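
Every example evaluates through semisup.confusion_matrix. A minimal sketch consistent with how the result is printed here, assuming rows are true classes and columns are predicted classes:

import numpy as np

def confusion_matrix(labels, predictions, num_labels):
    """Hypothetical sketch: entry [i, j] counts class i predicted as j."""
    matrix = np.zeros((num_labels, num_labels), dtype=np.int64)
    for true_label, predicted in zip(labels, predictions):
        matrix[true_label, predicted] += 1
    return matrix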
Example #7
def main(_):
    if FLAGS.logdir is not None:
        # FLAGS.logdir = FLAGS.logdir + '/t_' + datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
        _unsup_batch_size = FLAGS.unsup_batch_size if FLAGS.semisup else 0
        FLAGS.logdir = "{0}/i{1}_e{2}_s{3}_un{4}_d{5}_{6}_w{7}".format(
            FLAGS.logdir, image_shape[0], FLAGS.emb_size,
            FLAGS.sup_per_class, _unsup_batch_size, FLAGS.decay_steps,
            int(FLAGS.decay_factor * 100), FLAGS.warmup_steps)
        try:
            shutil.rmtree(FLAGS.logdir)
        except OSError as e:
            print("Error: %s - %s." % (e.filename, e.strerror))

    # Load image data from npy file
    train_images, test_images, train_labels, test_labels = dataset_tools.get_data(
        one_hot=False, test_size=FLAGS.test_size, image_shape=image_shape)

    unique, counts = np.unique(test_labels, return_counts=True)
    testset_distribution = dict(zip(unique, counts))

    train_X, train_Y = semisup.sample_by_label_v2(train_images, train_labels,
                                                  FLAGS.sup_per_class,
                                                  NUM_LABELS,
                                                  np.random.randint(0, 100))

    def aug(image, label):
        return apply_augmentation(
            image,
            target_shape=image_shape,
            params=dataset_tools.augmentation_params), label

    def aug_unsup(image):
        return apply_augmentation(image,
                                  target_shape=image_shape,
                                  params=dataset_tools.augmentation_params)

    graph = tf.Graph()
    with graph.as_default():
        # Create function that defines the network.
        architecture = getattr(semisup.architectures, FLAGS.architecture)
        model_function = partial(architecture,
                                 new_shape=None,
                                 img_shape=IMAGE_SHAPE,
                                 batch_norm_decay=FLAGS.batch_norm_decay,
                                 emb_size=FLAGS.emb_size)

        model = semisup.SemisupModel(model_function,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     emb_size=FLAGS.emb_size,
                                     dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up supervised inputs.
        t_images = tf.placeholder("float", shape=[None] + image_shape)
        t_labels = tf.placeholder(train_labels.dtype, shape=[None])
        dataset = Dataset.from_tensor_slices((t_images, t_labels))
        # Apply augmentation
        if FLAGS.augmentation:
            dataset = dataset.map(aug)
        dataset = dataset.shuffle(buffer_size=FLAGS.sup_per_class * NUM_LABELS)
        dataset = dataset.repeat().batch(FLAGS.sup_per_class * NUM_LABELS)
        iterator = dataset.make_initializable_iterator()
        t_sup_images, t_sup_labels = iterator.get_next()

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            unsup_t_images = tf.placeholder("float",
                                            shape=[None] + image_shape)
            unsup_dataset = Dataset.from_tensor_slices(unsup_t_images)
            # Apply augmentation
            if FLAGS.augmentation:
                unsup_dataset = unsup_dataset.map(aug_unsup)
            unsup_dataset = unsup_dataset.shuffle(
                buffer_size=FLAGS.unsup_batch_size)
            unsup_dataset = unsup_dataset.repeat().batch(
                FLAGS.unsup_batch_size)
            unsup_iterator = unsup_dataset.make_initializable_iterator()
            t_unsup_images = unsup_iterator.get_next()

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(t_sup_emb,
                                   t_unsup_emb,
                                   t_sup_labels,
                                   walker_weight=FLAGS.walker_weight,
                                   visit_weight=FLAGS.visit_weight)

            #model.add_emb_regularization(t_unsup_emb, weight=FLAGS.l1_weight)
        else:
            model.loss_aba = tf.constant(0)
            model.visit_loss = tf.constant(0)

        t_logit_loss = model.add_logit_loss(t_sup_logit,
                                            t_sup_labels,
                                            weight=FLAGS.logit_weight)
        # t_logit_loss = tf.constant(0)

        #model.add_emb_regularization(t_sup_emb, weight=FLAGS.l1_weight)

        t_learning_rate = tf.placeholder("float", shape=[])

        train_op = model.create_train_op(t_learning_rate)

        summary_op = tf.summary.merge_all()

        if FLAGS.logdir is not None:
            summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)
            saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:

        sess.run(iterator.initializer,
                 feed_dict={
                     t_images: train_X,
                     t_labels: train_Y
                 })
        if FLAGS.semisup:
            sess.run(unsup_iterator.initializer,
                     feed_dict={unsup_t_images: train_images})

        tf.global_variables_initializer().run()

        if FLAGS.restore_checkpoint is not None:
            variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope='net')
            restorer = tf.train.Saver(var_list=variables)
            restorer.restore(sess, FLAGS.restore_checkpoint)

        learning_rate_ = FLAGS.learning_rate

        for step in range(FLAGS.max_steps):
            lr = learning_rate_
            if step < FLAGS.warmup_steps:
                lr = 1e-6 + semisup.apply_envelope(
                    "log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)

            step_start_time = time.time()
            _, summaries, train_loss, aba_loss, visit_loss, logit_loss = sess.run(
                [
                    train_op, summary_op, model.train_loss, model.loss_aba,
                    model.visit_loss, t_logit_loss
                ], {t_learning_rate: lr})

            if not FLAGS.run_in_background:
                sys.stderr.write("\rstep: %d, Step time: %.4f sec" %
                                 (step, (time.time() - step_start_time)))
                sys.stdout.flush()

            if (step + 1) % FLAGS.eval_interval == 0 or step == 99 or step == 0:
                print('\n=======================')
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Target:', testset_distribution)
                print('Test error: %.2f%%' % test_err)
                print('Learning rate:', lr)
                print('train_loss:', train_loss)
                if FLAGS.semisup:
                    print('walker_loss_aba:', aba_loss)
                    print('visit_loss:', visit_loss)
                print('logit_loss:', logit_loss)
                print('Image shape:', IMAGE_SHAPE)
                print('emb_size: ', FLAGS.emb_size)
                print('sup_per_class: ', FLAGS.sup_per_class)
                if FLAGS.semisup:
                    print('unsup_batch_size: ', FLAGS.unsup_batch_size)
                print('semisup: ', FLAGS.semisup)
                print('augmentation: ', FLAGS.augmentation)
                print('decay_steps: ', FLAGS.decay_steps)
                print('decay_factor: ', FLAGS.decay_factor)
                print('warmup_steps: ', FLAGS.warmup_steps)
                if FLAGS.semisup:
                    print('walker_weight: ', FLAGS.walker_weight)
                    print('visit_weight: ', FLAGS.visit_weight)
                print('logit_weight: ', FLAGS.logit_weight)
                print('=======================\n')

                if FLAGS.logdir is not None:
                    sum_values = {'Test error': test_err}
                    summary_writer.add_summary(summaries, step)
                    for key, value in sum_values.items():
                        summary = tf.Summary(value=[
                            tf.Summary.Value(tag=key, simple_value=value)
                        ])
                        summary_writer.add_summary(summary, step)

            if FLAGS.logdir is not None and (step + 1) % 5000 == 0:
                path = saver.save(sess, FLAGS.logdir + '/checkpoint',
                                  model.step)
                print('@@model_path:%s' % path)

            if step % FLAGS.decay_steps == 0 and step > 0:
                learning_rate_ = learning_rate_ * FLAGS.decay_factor
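
This example replaces the queue runners of the earlier snippets with tf.data initializable iterators fed through placeholders. The pattern in isolation, as a self-contained sketch (array contents, shapes, and the batch size are made up for illustration):

import numpy as np
import tensorflow as tf

images = np.random.rand(100, 28, 28, 1).astype(np.float32)  # dummy data
labels = np.random.randint(0, 10, size=100)

t_images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
t_labels = tf.placeholder(tf.int64, shape=[None])

dataset = (tf.data.Dataset.from_tensor_slices((t_images, t_labels))
           .shuffle(buffer_size=100).repeat().batch(32))
iterator = dataset.make_initializable_iterator()
batch_images, batch_labels = iterator.get_next()

with tf.Session() as sess:
    # Feeding through placeholders keeps the arrays out of the graph proto.
    sess.run(iterator.initializer, {t_images: images, t_labels: labels})
    imgs, lbls = sess.run([batch_images, batch_labels])
    print(imgs.shape, lbls.shape)  # (32, 28, 28, 1) (32,)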
Example #8
def main(_):
  unsup_multiplier = NUM_TRAIN_IMAGES / NUM_LABELS / FLAGS.sup_per_class
  print(unsup_multiplier)

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None

  graph = tf.Graph()
  with graph.as_default():
    model = semisup.SemisupModel(semisup.architectures.densenet_model, NUM_LABELS,
                                 IMAGE_SHAPE)

    # Set up inputs.
    train_sup, train_labels_sup = data.build_input(FLAGS.cifar,
                                                   os.path.join(FLAGS.dataset_dir, TRAIN_FILE),
                                                   batch_size=FLAGS.sup_batch_size,
                                                   mode='train',
                                                   subset_factor=unsup_multiplier)

    if FLAGS.unsup_batch_size > 0:
      train_unsup, train_labels_unsup = data.build_input(FLAGS.cifar,
                                                         os.path.join(FLAGS.dataset_dir, TRAIN_FILE),
                                                         batch_size=FLAGS.unsup_batch_size,
                                                         mode='train')

    test_images, test_labels = data.build_input(FLAGS.cifar,
                                                os.path.join(FLAGS.dataset_dir, TEST_FILE),
                                                batch_size=TEST_BATCH_SIZE,
                                                mode='test')

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(train_sup)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    if FLAGS.unsup_batch_size > 0:
      t_unsup_emb = model.image_to_embedding(train_unsup)
      model.add_semisup_loss(
            t_sup_emb, t_unsup_emb, train_labels_sup, visit_weight=FLAGS.visit_weight)

    model.add_logit_loss(t_sup_logit, train_labels_sup)

    t_learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        model.step,
        FLAGS.decay_epochs * steps_per_epoch,
        FLAGS.decay_factor,
        staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    last_epoch = -1

    for step in range(FLAGS.max_epochs * int(steps_per_epoch)):
      _, summaries, tl = sess.run([train_op, summary_op, model.train_loss])

      epoch = math.floor(step / steps_per_epoch)
      if (epoch >= 0 and epoch % FLAGS.eval_interval == 0) or epoch == 1:
        if epoch == last_epoch: #don't log twice for same epoch
          continue

        last_epoch = epoch
        num_total_batches = 10000 // TEST_BATCH_SIZE
        print('Epoch: %d' % epoch)

        t_imgs, t_lbls = model.get_images(test_images, test_labels, num_total_batches, sess)
        test_pred = model.classify(t_imgs).argmax(-1)
        conf_mtx = semisup.confusion_matrix(t_lbls, test_pred, NUM_LABELS)
        test_err = (t_lbls != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        t_imgs, t_lbls = model.get_images(train_sup, train_labels_sup, num_total_batches, sess)
        train_pred = model.classify(t_imgs).argmax(-1)
        train_err = (t_lbls != train_pred).mean() * 100
        print('Train error: %.2f %%' % train_err)

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
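
model.get_images is used above to pull a fixed number of evaluation batches out of the queue tensors before classifying them. A hypothetical sketch of that helper:

import numpy as np

def get_images(image_tensor, label_tensor, num_batches, sess):
    """Hypothetical sketch: materialize queue batches into numpy arrays."""
    images, labels = [], []
    for _ in range(int(num_batches)):
        batch_images, batch_labels = sess.run([image_tensor, label_tensor])
        images.append(batch_images)
        labels.append(batch_labels)
    return np.vstack(images), np.hstack(labels)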
Example #9
def main(_):
  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  import numpy as np

  use_handpicked_subset = False
  if use_handpicked_subset:
    # Hand-picked indices of labeled training examples.
    indices = [374, 2507, 9755, 12953, 16507, 16873, 23474,
               23909, 30280, 35070, 49603, 50106, 51171, 51726, 51805, 55205,
               57251, 57296, 57779, 59154] + \
              [9924, 34058, 53476, 15715, 6428, 33598, 33464, 41753, 21250,
               26389, 12950, 12464, 3795, 6761, 5638, 3952, 8300, 5632, 1475,
               1875]
    # Alternative second half of the index set (unused):
    # [16644, 45576, 52886, 42140, 29201, 7767, 24, 134, 8464, 15022, 15715,
    #  15602, 11030, 3898, 10195, 1454, 3290, 5293, 5806, 274]
    sup_by_label = [[] for _ in range(NUM_LABELS)]
    for ind in indices:
      i = train_labels[ind]
      sup_by_label[i] = sup_by_label[i] + [train_images[ind]]

    for i in range(10):
      sup_by_label[i] = np.asarray(sup_by_label[i])

    sup_by_label = np.asarray(sup_by_label)


  add_random_samples = 20
  if add_random_samples > 0:
    rng = np.random.RandomState()
    indices = rng.choice(len(train_images), add_random_samples, False)

    for i in indices:
      l = train_labels[i]
      sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])


  graph = tf.Graph()
  with graph.as_default():
    model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                 IMAGE_SHAPE)

    # Set up inputs.
    t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                             FLAGS.unsup_batch_size)
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels, visit_weight=FLAGS.visit_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        model.step,
        FLAGS.decay_steps,
        FLAGS.decay_factor,
        staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      _, summaries = sess.run([train_op, summary_op])
      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
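
semisup.create_per_class_inputs builds batches that contain a fixed number of samples from every class, which keeps the supervised batch balanced. A plausible sketch from the same TF1 queue primitives (the real helper may differ in shuffling details):

def create_per_class_inputs(images_by_label, n_per_class):
    """Hypothetical sketch: each batch holds n_per_class examples per class."""
    image_batches, label_batches = [], []
    for label, images in enumerate(images_by_label):
        labels = tf.fill([len(images)], label)
        image, label_t = tf.train.slice_input_producer([images, labels])
        imgs, lbls = tf.train.batch([image, label_t], batch_size=n_per_class)
        image_batches.append(imgs)
        label_batches.append(lbls)
    return tf.concat(image_batches, 0), tf.concat(label_batches, 0)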
Example #10
def main(_):

  train_images, train_labels = data.load_training_data(FLAGS.cifar)
  test_images, test_labels = data.load_test_data(FLAGS.cifar)

  print(train_images.shape, train_labels.shape)

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():

    def augment(image):
        # image_size = 28
        # image = tf.image.resize_image_with_crop_or_pad(
        #    image, image_size+4, image_size+4)
        # image = tf.random_crop(image, [image_size, image_size, 3])
        image = tf.image.random_flip_left_right(image)
        # Brightness/saturation/constrast provides small gains .2%~.5% on cifar.
        # image = tf.image.random_brightness(image, max_delta=63. / 255.)
        # image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
        return image

    model_func = getattr(semisup.architectures, FLAGS.model)
    model = semisup.SemisupModel(model_func, NUM_LABELS,
                                 IMAGE_SHAPE, emb_size=256, dropout_keep_prob=FLAGS.dropout_keep_prob,
                                 augmentation_function=augment)


    if FLAGS.random_batches:
      sup_lbls = np.asarray(
          np.hstack([np.ones(len(i)) * ind
                     for ind, i in enumerate(sup_by_label)]), int)
      sup_images = np.vstack(sup_by_label)
      batch_size = FLAGS.sup_per_batch * NUM_LABELS
      t_sup_images, t_sup_labels = semisup.create_input(sup_images, sup_lbls, batch_size)
    else:
      t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    if FLAGS.semisup:
      if FLAGS.equal_cls_unsup:
        allimgs_bylabel = semisup.sample_by_label(train_images, train_labels,
                                                  int(NUM_TRAIN_IMAGES / NUM_LABELS), NUM_LABELS, seed)
        t_unsup_images, _ = semisup.create_per_class_inputs(
          allimgs_bylabel, FLAGS.sup_per_batch)
      else:
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)

      t_unsup_emb = model.image_to_embedding(t_unsup_images)
      proximity_weight = tf.placeholder("float", shape=[])
      visit_weight = tf.placeholder("float", shape=[])
      walker_weight = tf.placeholder("float", shape=[])

      model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels,
        walker_weight=walker_weight, visit_weight=visit_weight, proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
      FLAGS.learning_rate,
      model.step,
      FLAGS.decay_steps,
      FLAGS.decay_factor,
      staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()
  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      feed_dict = {}
      if FLAGS.semisup:
        if FLAGS.proximity_loss:
          feed_dict = {
            walker_weight: FLAGS.walker_weight,
            visit_weight: 0,
            proximity_weight: FLAGS.visit_weight,
            # Ramped alternatives:
            # walker_weight: 0.1 + apply_envelope("lin", step, FLAGS.walker_weight, 500, FLAGS.decay_steps),
            # proximity_weight: 0.1 + apply_envelope("lin", step, FLAGS.visit_weight, 500, FLAGS.decay_steps),
            # t_logit_weight: FLAGS.logit_weight,
            # t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0),
          }
        else:
          feed_dict = {
            walker_weight: FLAGS.walker_weight,
            visit_weight: FLAGS.visit_weight,
            proximity_weight: 0,
            # Ramped alternatives:
            # walker_weight: 0.1 + apply_envelope("lin", step, FLAGS.walker_weight, 500, FLAGS.decay_steps),
            # visit_weight: 0.1 + apply_envelope("lin", step, FLAGS.visit_weight, 500, FLAGS.decay_steps),
            # t_logit_weight: FLAGS.logit_weight,
            # t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0),
          }

      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], feed_dict)


      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images, sess).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        #saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
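
All of these scripts revolve around model.add_semisup_loss, whose walker_weight, visit_weight, and proximity_weight knobs are scheduled above. A condensed sketch of the association loss from "Learning by Association" that it implements (simplified: the proximity term and the library's exact reductions are omitted):

def semisup_loss_sketch(sup_emb, unsup_emb, sup_labels,
                        walker_weight=1.0, visit_weight=1.0):
    """Hypothetical sketch of the walker + visit association losses."""
    match = tf.matmul(sup_emb, unsup_emb, transpose_b=True)  # similarity scores
    p_ab = tf.nn.softmax(match)                # labeled -> unlabeled
    p_ba = tf.nn.softmax(tf.transpose(match))  # unlabeled -> labeled
    p_aba = tf.matmul(p_ab, p_ba)              # round-trip probabilities

    # Walker loss: a round trip should end at a sample of the starting class.
    equal = tf.cast(tf.equal(tf.expand_dims(sup_labels, 1),
                             tf.expand_dims(sup_labels, 0)), tf.float32)
    target = equal / tf.reduce_sum(equal, axis=1, keepdims=True)
    walker_loss = walker_weight * -tf.reduce_mean(
        tf.reduce_sum(target * tf.log(p_aba + 1e-8), axis=1))

    # Visit loss: the unlabeled samples should be visited roughly uniformly.
    visit_prob = tf.reduce_mean(p_ab, axis=0)
    visit_loss = visit_weight * -tf.reduce_mean(tf.log(visit_prob + 1e-8))
    return walker_loss + visit_loss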