Example No. 1
def init_iterators():
    input_graph = tf.Graph()
    with input_graph.as_default():
        sup_images_iterator, sup_labels_iterator = semisup.create_per_class_inputs(
            sup_by_label, FLAGS.sup_per_batch)

        input_sess = tf.Session(graph=input_graph)
        input_sess.run(tf.global_variables_initializer())

        # The queue runners live in input_graph, so they must be collected
        # (and started) while that graph is still the default.
        input_coord = tf.train.Coordinator()
        input_threads = tf.train.start_queue_runners(sess=input_sess,
                                                     coord=input_coord)

    return sup_images_iterator, sup_labels_iterator, input_sess, input_threads, input_coord
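A minimal driver for the helper above, assuming the semisup queue-runner API used in this example: batches are evaluated in the dedicated input session, and the pipeline is shut down through the returned coordinator.

# Hypothetical usage sketch for init_iterators().
sup_images, sup_labels, input_sess, input_threads, input_coord = init_iterators()
try:
    for _ in range(10):
        images, labels = input_sess.run([sup_images, sup_labels])
finally:
    input_coord.request_stop()
    input_coord.join(input_threads)
    input_sess.close()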
Example No. 2
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS, IMAGE_SHAPE)

        # Set up inputs.
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
            sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_unsup_emb = model.image_to_embedding(t_unsup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        model.add_semisup_loss(t_sup_emb,
                               t_unsup_emb,
                               t_sup_labels,
                               visit_weight=FLAGS.visit_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
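semisup.confusion_matrix itself is not shown in these examples. As a reference, here is a minimal NumPy sketch of the quantity the evaluation loop prints (an assumption about the helper's behavior, not its actual implementation):

import numpy as np

def confusion_matrix(labels, preds, num_labels):
    # mtx[i, j] counts samples of true class i predicted as class j.
    mtx = np.zeros((num_labels, num_labels), dtype=np.int64)
    for t, p in zip(labels, preds):
        mtx[t, p] += 1
    return mtx

# The printed test error is then the off-diagonal mass:
# test_err = (labels != preds).mean() * 100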
Example No. 3
def main(_):
    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(
        0, 1000)
    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS,
                                           seed)

    # Optionally add a few extra labeled samples for every class.
    if FLAGS.testadd:
        num_to_add = 10
        for i in range(10):
            items = np.where(train_labels == i)[0]
            inds = np.random.choice(len(items), num_to_add, replace=False)
            sup_by_label[i] = np.vstack(
                [sup_by_label[i], train_images[items[inds]]])

    add_random_samples = 0
    if add_random_samples > 0:
        rng = np.random.RandomState()
        indices = rng.choice(len(train_images), add_random_samples, False)

        for i in indices:
            l = train_labels[i]
            sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])
            print(l)

    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model,
                                     NUM_LABELS,
                                     IMAGE_SHAPE,
                                     dropout_keep_prob=0.8)

        # Set up inputs.
        if FLAGS.random_batches:
            sup_lbls = np.asarray(
                np.hstack([
                    np.ones(len(i)) * ind for ind, i in enumerate(sup_by_label)
                ]), np.int64)
            sup_images = np.vstack(sup_by_label)
            t_sup_images, t_sup_labels = semisup.create_input(
                sup_images, sup_lbls, FLAGS.sup_per_batch * NUM_LABELS)
        else:
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            if FLAGS.equal_cls_unsup:
                allimgs_bylabel = semisup.sample_by_label(
                    train_images, train_labels, 5000, NUM_LABELS, seed)
                t_unsup_images, _ = semisup.create_per_class_inputs(
                    allimgs_bylabel, FLAGS.sup_per_batch)
            else:
                t_unsup_images, _ = semisup.create_input(
                    train_images, train_labels, FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(t_sup_emb,
                                   t_unsup_emb,
                                   t_sup_labels,
                                   walker_weight=FLAGS.walker_weight,
                                   visit_weight=FLAGS.visit_weight,
                                   proximity_weight=FLAGS.proximity_weight)
        model.add_logit_loss(t_sup_logit, t_sup_labels)

        t_learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                     model.step,
                                                     FLAGS.decay_steps,
                                                     FLAGS.decay_factor,
                                                     staircase=True)
        train_op = model.create_train_op(t_learning_rate)
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for step in range(FLAGS.max_steps):
            _, summaries = sess.run([train_op, summary_op])
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('Step: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred,
                                                    NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)
                print()

                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Test Err', simple_value=test_err)
                ])

                summary_writer.add_summary(summaries, step)
                summary_writer.add_summary(test_summary, step)

                saver.save(sess, FLAGS.logdir, model.step)

        coord.request_stop()
        coord.join(threads)
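semisup.sample_by_label is used throughout but never shown. Judging by how sup_by_label is indexed and extended with np.vstack above, it returns a list with one array of images per class; a plausible NumPy sketch (an assumption, not the library code):

import numpy as np

def sample_by_label(images, labels, n_per_class, num_labels, seed=None):
    # Return a list of num_labels arrays, each holding n_per_class images.
    rng = np.random.RandomState(seed)
    per_class = []
    for cls in range(num_labels):
        idx = np.where(labels == cls)[0]
        per_class.append(images[rng.choice(idx, n_per_class, replace=False)])
    return per_class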
Example No. 4
def main(argv):
    del argv

    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, _ = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # need to add additional labels for virtual embeddings
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            visit_weight_envelope_steps = (FLAGS.walker_weight_envelope_steps
                                           if FLAGS.visit_weight_envelope_steps
                                           == -1 else
                                           FLAGS.visit_weight_envelope_steps)
            visit_weight_envelope_delay = (FLAGS.walker_weight_envelope_delay
                                           if FLAGS.visit_weight_envelope_delay
                                           == -1 else
                                           FLAGS.visit_weight_envelope_delay)
            visit_weight = apply_envelope(
                type=FLAGS.visit_weight_envelope,
                step=model.step,
                final_weight=FLAGS.visit_weight,
                growing_steps=visit_weight_envelope_steps,
                delay=visit_weight_envelope_delay)
            walker_weight = apply_envelope(
                type=FLAGS.walker_weight_envelope,
                step=model.step,
                final_weight=FLAGS.walker_weight,
                growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                delay=FLAGS.walker_weight_envelope_delay)
            tf.summary.scalar('Weights_Visit', visit_weight)
            tf.summary.scalar('Weights_Walker', walker_weight)

            if FLAGS.unsup_samples != 0:
                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(FLAGS.learning_rate,
                                           model.step,
                                           FLAGS.decay_steps,
                                           FLAGS.decay_factor,
                                           staircase=True),
                FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
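apply_envelope is also external to this file. Its arguments (type, step, final_weight, growing_steps, delay) suggest a weight that ramps from 0 to final_weight over growing_steps after an initial delay. A scalar sketch of that behavior, assuming 'lin' and 'log' envelope types; the real helper presumably operates on the model.step tensor with TF ops:

import numpy as np

def apply_envelope(type, step, final_weight, growing_steps, delay):
    # Fraction of the ramp completed, clipped to [0, 1].
    t = np.clip((step - delay) / float(growing_steps), 0.0, 1.0)
    if type == 'lin':
        value = t
    elif type == 'log':
        value = np.log1p(9.0 * t) / np.log(10.0)  # 0 -> 1, fast early growth
    else:
        value = 1.0  # no envelope
    return value * final_weight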
Example No. 5
def main(_):
    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, _ = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) revert this to the general case
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert

                # if hasattr(dataset_tools, 'augmentation_params'):
                #    augmentation_function = partial(
                #        apply_augmentation, params=dataset_tools.augmentation_params)
                # else:
                #    augmentation_function = apply_affine_augmentation
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # need to add additional labels for virtual embeddings
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.mmd:
                sys.path.insert(0, '/usr/wiss/haeusser/libs/opt-mmd/gan')
                from mmd import mix_rbf_mmd2

                bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]  # original

                t_sup_flat = tf.reshape(t_sup_emb,
                                        [FLAGS.sup_per_batch * num_labels, -1])
                t_unsup_flat = tf.reshape(t_unsup_emb,
                                          [FLAGS.unsup_batch_size, -1])
                mmd_loss = mix_rbf_mmd2(t_sup_flat,
                                        t_unsup_flat,
                                        sigmas=bandwidths)
                tf.losses.add_loss(mmd_loss)
                tf.summary.scalar('MMD_loss', mmd_loss)
            else:
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1 else
                    FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1 else
                    FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)

                if FLAGS.unsup_samples != 0:
                    model.add_semisup_loss(t_sup_emb,
                                           t_unsup_emb,
                                           t_sup_labels,
                                           visit_weight=visit_weight,
                                           walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate schedule if necessary.
            if FLAGS.custom_lr_vals is not None and FLAGS.custom_lr_steps is not None:
                boundaries = [
                    tf.convert_to_tensor(x, tf.int64)
                    for x in FLAGS.custom_lr_steps
                ]

                t_learning_rate = piecewise_constant(model.step, boundaries,
                                                     FLAGS.custom_lr_vals)
            else:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            ''' BEGIN EVIL STUFF !!!
            checkpoint_path = '/usr/wiss/haeusser/experiments/inception_v4/model.ckpt'
            mapping = dict()
            for x in slim.get_model_variables():
                name = x.name[:-2]
                ok = True
                for banned in ['Logits', 'fc1', 'fully_connected', 'ExponentialMovingAverage']:
                    if banned in name:
                        ok = False
                if ok:
                    mapping[name[4:]] = x

            init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                checkpoint_path, mapping)

            # Create an initial assignment function.
            def InitAssignFn(sess):
                sess.run(init_assign_op, init_feed_dict)
                print("#################################### Checkpoint loaded.")
             '''  # END EVIL STUFF !!!

            slim.learning.train(
                train_op,
                # init_fn=InitAssignFn,  # EVIL, too
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
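The custom learning-rate branch above relies on a piecewise_constant helper. Assuming it matches the semantics of tf.train.piecewise_constant in TF 1.x, the schedule it builds looks like this:

# lr = 0.1 until step 1000, 0.01 until step 5000, 0.001 afterwards.
global_step = tf.train.get_or_create_global_step()
t_learning_rate = tf.train.piecewise_constant(
    global_step, boundaries=[1000, 5000], values=[0.1, 0.01, 0.001])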
Example No. 6
def main(_):
  FLAGS.emb_size = 128
  optimizer = 'adam'

  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)

  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():
    model_func = semisup.architectures.mnist_model
    if FLAGS.dropout_keep_prob < 1:
        model_func = semisup.architectures.mnist_model_dropout

    model = semisup.SemisupModel(model_func, NUM_LABELS,
                                 IMAGE_SHAPE, optimizer=optimizer, emb_size=FLAGS.emb_size,
                                 dropout_keep_prob=FLAGS.dropout_keep_prob)

    # Set up inputs.
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
      sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    t_unsup_images = tf.placeholder("float", shape=[None] + IMAGE_SHAPE)

    proximity_weight = tf.placeholder("float", shape=[])
    visit_weight = tf.placeholder("float", shape=[])
    walker_weight = tf.placeholder("float", shape=[])
    t_logit_weight = tf.placeholder("float", shape=[])
    t_l1_weight = tf.placeholder("float", shape=[])
    t_learning_rate = tf.placeholder("float", shape=[])

    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    model.add_semisup_loss(
      t_sup_emb, t_unsup_emb, t_sup_labels,
      walker_weight=walker_weight, visit_weight=visit_weight, proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels, weight=t_logit_weight)

    model.add_emb_regularization(t_sup_emb, weight=t_l1_weight)
    model.add_emb_regularization(t_unsup_emb, weight=t_l1_weight)

    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  sess = tf.InteractiveSession(graph=graph)

  unsup_images_iterator = semisup.create_input(train_images, train_labels,
                                               FLAGS.unsup_batch_size)
  tf.global_variables_initializer().run()

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  use_new_visit_loss = False
  for step in range(FLAGS.max_steps):
    unsup_images, _ = sess.run(unsup_images_iterator)

    if use_new_visit_loss:
      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], {
        t_unsup_images: unsup_images,
        walker_weight: FLAGS.walker_weight,
        proximity_weight: 0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0)
                          - apply_envelope("lin", step, FLAGS.visit_weight, 2000, FLAGS.warmup_steps),
        t_logit_weight: FLAGS.logit_weight,
        t_l1_weight: FLAGS.l1_weight,
        visit_weight: apply_envelope("lin", step, FLAGS.visit_weight, 2000, FLAGS.warmup_steps),
        t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
      })
    else:
      _, summaries, train_loss = sess.run([train_op, summary_op, model.train_loss], {
        t_unsup_images: unsup_images,
        walker_weight: FLAGS.walker_weight,
        visit_weight: 0.3 + apply_envelope("lin", step, 0.7, FLAGS.warmup_steps, 0),
        proximity_weight: 0,
        t_logit_weight: FLAGS.logit_weight,
        t_l1_weight: FLAGS.l1_weight,
        t_learning_rate: 5e-5 + apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)
      })

    if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
      print('Step: %d' % step)
      test_pred = model.classify(test_images).argmax(-1)
      conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
      test_err = (test_labels != test_pred).mean() * 100
      print(conf_mtx)
      print('Test error: %.2f %%' % test_err)
      print('Train loss: %.2f ' % train_loss)
      print()

      test_summary = tf.Summary(
        value=[tf.Summary.Value(
          tag='Test Err', simple_value=test_err)])

      summary_writer.add_summary(summaries, step)
      summary_writer.add_summary(test_summary, step)
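Unlike the earlier examples, this one drives all loss weights and the learning rate through placeholders, recomputing the schedule in Python on every step. The pattern in miniature (names are illustrative only):

import tensorflow as tf

t_weight = tf.placeholder(tf.float32, shape=[])
t_weighted = t_weight * tf.constant(4.0)

with tf.Session() as sess:
    for step in range(3):
        ramp = min(step / 2.0, 1.0)  # stand-in for apply_envelope
        print(sess.run(t_weighted, {t_weight: ramp}))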
Example No. 7
def main(_):
    if FLAGS.logdir is not None:
        FLAGS.logdir = FLAGS.logdir + '/t_mnist_eval'
        try:
            shutil.rmtree(FLAGS.logdir)
        except OSError as e:
            print ("Error: %s - %s." % (e.filename, e.strerror))

    train_images, train_labels = mnist_tools.get_data('train')
    test_images, test_labels = mnist_tools.get_data('test')

    # Sample labeled training subset.
    if FLAGS.sup_seed >= 0:
      seed = FLAGS.sup_seed
    elif FLAGS.sup_seed == -2:
      seed = FLAGS.sup_per_class
    else:
      seed = np.random.randint(0, 1000)

    print('Seed:', seed)
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, NUM_LABELS, seed)


    graph = tf.Graph()
    with graph.as_default():
        model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                     IMAGE_SHAPE, dropout_keep_prob=FLAGS.dropout_keep_prob)

        # Set up inputs.
        t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                    sup_by_label, FLAGS.sup_per_batch)

        # Compute embeddings and logits.
        t_sup_emb = model.image_to_embedding(t_sup_images)
        t_sup_logit = model.embedding_to_logit(t_sup_emb)

        # Add losses.
        if FLAGS.semisup:
            t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                         FLAGS.unsup_batch_size)

            t_unsup_emb = model.image_to_embedding(t_unsup_images)
            model.add_semisup_loss(
                    t_sup_emb, t_unsup_emb, t_sup_labels,
                    walker_weight=FLAGS.walker_weight, visit_weight=FLAGS.visit_weight)

            #model.add_emb_regularization(t_unsup_emb, weight=FLAGS.l1_weight)

        model.add_logit_loss(t_sup_logit, t_sup_labels, weight=FLAGS.logit_weight)

        #model.add_emb_regularization(t_sup_emb, weight=FLAGS.l1_weight)

        t_learning_rate = tf.placeholder("float", shape=[])

        train_op = model.create_train_op(t_learning_rate)

        summary_op = tf.summary.merge_all()
        if FLAGS.logdir is not None:
            summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)
            saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        learning_rate_ = FLAGS.learning_rate

        for step in range(FLAGS.max_steps):
            lr = learning_rate_
            if step < FLAGS.warmup_steps:
                lr = 1e-6 + semisup.apply_envelope("log", step, FLAGS.learning_rate, FLAGS.warmup_steps, 0)

            _, summaries = sess.run([train_op, summary_op], {
              t_learning_rate: lr
            })

            sys.stderr.write("\rstep: %d" % step)
            sys.stderr.flush()
    
            if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
                print('\nStep: %d' % step)
                test_pred = model.classify(test_images, sess).argmax(-1)
                conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
                test_err = (test_labels != test_pred).mean() * 100
                print(conf_mtx)
                print('Test error: %.2f %%' % test_err)

    
                if FLAGS.logdir is not None:
                    sum_values = {
                        'Test error': test_err
                    }
                    summary_writer.add_summary(summaries, step)
                    for key, value in sum_values.items():
                        summary = tf.Summary(
                            value=[tf.Summary.Value(tag=key, simple_value=value)])
                        summary_writer.add_summary(summary, step)

            if step % FLAGS.decay_steps == 0 and step > 0:
                learning_rate_ = learning_rate_ * FLAGS.decay_factor

        coord.request_stop()
        coord.join(threads)

    print('FINAL RESULTS:')
    print('Test error: %.2f %%' % (test_err))
    print('final_score', 1 - test_err/100)

    print('@@test_error:%.4f' % (test_err/100))
    print('@@train_loss:%.4f' % 0)
    print('@@reg_loss:%.4f' % 0)
    print('@@estimated_error:%.4f' % 0)
    print('@@centroid_norm:%.4f' % 0)
    print('@@emb_norm:%.4f' % 0)
    print('@@k_score:%.4f' % 0)
    print('@@svm_score:%.4f' % 0)
Example No. 8
def main(argv):
    del argv

    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                      merge_devices=True)):

            # Set up inputs.
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(
                    t_sup_images, [0, 0, 0, 0],
                    [FLAGS.sup_per_batch * (
                        num_labels - FLAGS.remove_classes)] +
                    image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels], minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            variables_to_train = [
                v for v in tf.trainable_variables()
                if v.name.startswith('net/fully_connected')
            ]
            for v in variables_to_train:
                print(v.name, v.shape)

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(
                    FLAGS.learning_rate,
                    model.step,
                    FLAGS.decay_steps,
                    FLAGS.decay_factor,
                    staircase=True),
                FLAGS.minimum_learning_rate)
            total_loss = tf.losses.get_total_loss()
            optimizer = tf.train.AdamOptimizer(t_learning_rate)

            variables_to_restore = None
            restore_ckpt = None
            if FLAGS.pretrained:
              variables_to_restore = [v for v in tf.trainable_variables()
                                      if not (v.name.startswith('net/fully_connected'))]
              restore_ckpt = FLAGS.pretrained
              for v in variables_to_restore:
                print(v.name)
            learning.train(graph, FLAGS.logdir,
                           total_loss, optimizer,
                           variables_to_train, model.step,
                           num_steps=FLAGS.max_steps, log_interval=20,
                           summary_interval=100, snapshot_interval=5000,
                           variables_to_restore=variables_to_restore, restore_ckpt=restore_ckpt)
            # Hand off to the custom learning.train above and stop here; the
            # original slim-based training path below is kept but unreachable.
            return

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(max_to_keep=FLAGS.max_checkpoints,
                                   keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)  # pylint:disable=line-too-long

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
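This example fine-tunes only the 'net/fully_connected' head while restoring the rest from a pretrained checkpoint. The split in miniature, with an assumed toy model (scope names follow the example; everything else is illustrative):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('net'):
    h = tf.layers.dense(x, 8, name='backbone')
    logits = tf.layers.dense(h, 2, name='fully_connected')
loss = tf.reduce_mean(tf.square(logits))

head_vars = [v for v in tf.trainable_variables()
             if v.name.startswith('net/fully_connected')]
backbone_vars = [v for v in tf.trainable_variables() if v not in head_vars]

# Only the head is updated; the backbone would be restored from a checkpoint.
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss, var_list=head_vars)
restorer = tf.train.Saver(var_list=backbone_vars)
# restorer.restore(sess, FLAGS.pretrained) would load the frozen backbone.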
Example No. 9
def train_test(FLAGS):

    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, train_images_label = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, train_images_label = dataset_tools.get_data(
            'unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))
        # TODO: make sample selections per class (done).
        #unsup_by_label = semisup.sample_by_label(train_images_unlabeled, train_images_label,
        #                                       FLAGS.unsup_samples/num_labels+num_labels, num_labels,
        #                                       seed)

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            #print(t_sup_images.shape)
            #with tf.Session() as sess: print (t_sup_images.eval().shape)
            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs1, _):
                    inputs = tf.cast(inputs1, tf.float32)
                    inputs = tf.image.adjust_brightness(
                        inputs, tf.random_uniform((1, 1), 0.0, 0.5))
                    inputs = tf.image.random_contrast(inputs, 0.3, 1)
                    # inputs = tf.image.per_image_standardization(inputs)
                    inputs = tf.image.random_hue(inputs, 0.05)
                    inputs = tf.image.random_saturation(inputs, 0.5, 1.1)

                    def f1():
                        return tf.abs(inputs)  # augmented branch

                    def f2():
                        return tf.abs(inputs1)  # original branch

                    return tf.cond(tf.less(tf.random_uniform([], 0.0, 1), 0.5),
                                   f1, f2)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.unsup_samples != 0:
                t_unsup_emb = model.image_to_embedding(t_unsup_images)
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1 else
                    FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1 else
                    FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,  # pylint:disable=line-too-long
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)

                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate
            if FLAGS.learning_rate_type is None:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)
            elif FLAGS.learning_rate_type == 'exp2':
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.maximum_learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='exp_range',
                                         gamma=0.9999),
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994))

            else:
                t_learning_rate = tf.maximum(
                    cyclic_learning_rate(model.step,
                                         FLAGS.minimum_learning_rate,
                                         FLAGS.learning_rate,
                                         FLAGS.learning_rate_cycle_step,
                                         mode='triangular',
                                         gamma=0.9994),
                    FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            final_loss = slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
                #session_wrapper=tf_debug.LocalCLIDebugWrapperSession
            )

            print(final_loss)
    return final_loss
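cyclic_learning_rate is imported from outside this file. Assuming its 'triangular' mode follows Smith's cyclical learning-rate policy, a minimal TF sketch of that schedule:

import tensorflow as tf

def triangular_clr(step, min_lr, max_lr, step_size):
    # Oscillates linearly between min_lr and max_lr with period 2 * step_size.
    step = tf.cast(step, tf.float32)
    cycle = tf.floor(1.0 + step / (2.0 * step_size))
    x = tf.abs(step / step_size - 2.0 * cycle + 1.0)
    return min_lr + (max_lr - min_lr) * tf.maximum(0.0, 1.0 - x)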
Example No. 10
def main(_):
  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  import numpy as np
  if 0:  # Disabled: a hand-picked set of labeled indices, kept for reference.

    indices = [374, 2507, 9755, 12953, 16507, 16873, 23474,
            23909, 30280, 35070, 49603, 50106, 51171, 51726, 51805, 55205, 57251, 57296, 57779, 59154] + \
               [16644, 45576, 52886, 42140, 29201, 7767, 24, 134, 8464, 15022, 15715, 15602, 11030, 3898, 10195, 1454,
                3290, 5293, 5806, 274]
    indices = [374, 2507, 9755, 12953, 16507, 16873, 23474,
            23909, 30280, 35070, 49603, 50106, 51171, 51726, 51805, 55205, 57251, 57296, 57779, 59154] + \
               [9924, 34058, 53476, 15715, 6428, 33598, 33464, 41753, 21250, 26389, 12950,
                12464, 3795, 6761, 5638, 3952, 8300, 5632, 1475, 1875]
    sup_by_label = [ [] for i in range(NUM_LABELS)]
    for ind in indices:
      i = train_labels[ind]
      sup_by_label[i] = sup_by_label[i] + [train_images[ind]]

    for i in range(10):
      sup_by_label[i] = np.asarray(sup_by_label[i])

    sup_by_label = np.asarray(sup_by_label)


  add_random_samples = 20
  if add_random_samples > 0:
    rng = np.random.RandomState()
    indices = rng.choice(len(train_images), add_random_samples, False)

    for i in indices:
      l = train_labels[i]
      sup_by_label[l] = np.vstack([sup_by_label[l], [train_images[i]]])


  graph = tf.Graph()
  with graph.as_default():
    model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                 IMAGE_SHAPE)

    # Set up inputs.
    t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                             FLAGS.unsup_batch_size)
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels, visit_weight=FLAGS.visit_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        model.step,
        FLAGS.decay_steps,
        FLAGS.decay_factor,
        staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      _, summaries = sess.run([train_op, summary_op])
      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
Example No. 11
def main(argv):
    del argv
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    np.random.seed(seed)

    # Load data.
    if FLAGS.dataset.startswith('gs'):
        if FLAGS.target_dataset is not None:
            raise NotImplementedError(
                'target dataset is not supported for GS filesystem')
        train_images, train_labels = gs_data.get_data('train', FLAGS.dataset)
        train_images_unlabeled, _ = gs_data.get_data('unlabeled',
                                                     FLAGS.dataset)
        test_images, test_labels = gs_data.get_data('test', FLAGS.dataset)
    else:
        dataset_tools = import_module('tools.' + FLAGS.dataset)
        train_images, train_labels = dataset_tools.get_data('train')
        if FLAGS.target_dataset is not None:
            target_dataset_tools = import_module('tools.' +
                                                 FLAGS.target_dataset)
            train_images_unlabeled, _ = target_dataset_tools.get_data(
                FLAGS.target_dataset_split)
        else:
            train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

        test_images, test_labels = dataset_tools.get_data('test')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    tokenizer = None
    encoder = None
    nb_words = None
    seq_len = None
    if architecture == semisup.architectures.cnn_sentiment_model:
        if FLAGS.dataset.startswith('gs'):
            dataset_tools = gs_data
            # pickle needs bytes, so read the file in binary mode.
            with BytesIO(
                    file_io.read_file_to_string(
                        '{}/tokenizer.pkl'.format(FLAGS.dataset),
                        binary_mode=True)) as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)
        else:
            with open(FLAGS.dictionaries, 'rb') as fin:
                tokenizer, encoder, nb_words, seq_len = pickle.load(fin)

        dataset_tools.IMAGE_SHAPE = [seq_len]
        dataset_tools.NUM_LABELS = len(encoder.classes_)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Requested more unlabeled samples ({}) than are available in the '
            'unlabeled set ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            tf.set_random_seed(seed)

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled,
                                                  None,
                                                  FLAGS.unsup_batch_size,
                                                  seed=seed,
                                                  shuffle=FLAGS.shuffle_input)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label,
                FLAGS.sup_per_batch,
                seed=seed,
                shuffle=FLAGS.shuffle_input)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation.
            if FLAGS.augmentation:
                # TODO(haeusser) generalize augmentation
                def _random_invert(inputs, _):
                    # Invert each image in the batch with probability 0.5.
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    # Reshape to [batch, 1, 1, 1] so the mask broadcasts
                    # over height, width and channels.
                    randu = tf.reshape(randu, [-1, 1, 1, 1])
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255. * randu)

                augmentation_function = _random_invert
            else:
                augmentation_function = None

            # Create function that defines the network.
            if architecture == semisup.architectures.cnn_sentiment_model:

                embedding_weights = None
                dim = None
                if FLAGS.w2v is not None:
                    d = dict()
                    for path in FLAGS.w2v.split(','):
                        tf.logging.info('Loading word2vec: {}'.format(path))
                        if path.startswith('gs'):
                            s = file_io.read_file_to_string(path)
                            try:
                                # Python 2: decode, replacing invalid chars.
                                s = unicode(s, errors='replace')
                            except NameError:
                                pass  # Python 3: already a str.

                            with StringIO(s) as fin:
                                d = utils.load_word2vec(fin, d)
                        else:
                            with open(path, 'r') as fin:
                                d = utils.load_word2vec(fin, d)

                        if dim is None:
                            # Infer the embedding dimension from any entry.
                            dim = len(next(iter(d.values())))

                    embedding_weights = utils.get_embedding_weights(
                        nb_words, tokenizer, d, dim)

                if dim is None:
                    dim = FLAGS.word_embedding_dim

                model_function = partial(
                    architecture,
                    nb_words=nb_words,
                    embedding_dim=dim,
                    static_embedding=FLAGS.static_word_embeddings == 0,
                    embedding_weights=embedding_weights,
                    seed=seed,
                    embedding_dropout=FLAGS.embedding_dropout,
                    # Common architecture arguments:
                    img_shape=image_shape,
                    batch_norm_decay=FLAGS.batch_norm_decay,
                    emb_size=FLAGS.emb_size)
            else:
                model_function = partial(
                    architecture,
                    new_shape=new_shape,
                    img_shape=image_shape,
                    augmentation_function=augmentation_function,
                    batch_norm_decay=FLAGS.batch_norm_decay,
                    emb_size=FLAGS.emb_size)

            ti = tf.constant(test_images,
                             dtype=np.float32,
                             shape=list(test_images.shape),
                             name='test_in')
            tl = tf.constant(test_labels,
                             dtype=np.float32,
                             shape=list(test_labels.shape),
                             name='test_label')
            # Set up semisup model.
            model = semisup.SemisupModel(model_function,
                                         num_labels,
                                         image_shape,
                                         test_in=ti,
                                         test_label=tl)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat([
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ], 0)

                if not FLAGS.remove_classes:
                    # Need to add additional labels for virtual embeddings.
                    t_sup_labels = tf.concat([
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ], 0)

            t_sup_logit = model.embedding_to_logit(t_sup_emb, seed=seed)

            # Add losses.
            visit_weight_envelope_steps = (
                FLAGS.walker_weight_envelope_steps
                if FLAGS.visit_weight_envelope_steps == -1
                else FLAGS.visit_weight_envelope_steps)
            visit_weight_envelope_delay = (
                FLAGS.walker_weight_envelope_delay
                if FLAGS.visit_weight_envelope_delay == -1
                else FLAGS.visit_weight_envelope_delay)
            visit_weight = apply_envelope(
                type=FLAGS.visit_weight_envelope,
                step=model.step,
                final_weight=FLAGS.visit_weight,
                growing_steps=visit_weight_envelope_steps,
                delay=visit_weight_envelope_delay)
            walker_weight = apply_envelope(
                type=FLAGS.walker_weight_envelope,
                step=model.step,
                final_weight=FLAGS.walker_weight,
                growing_steps=FLAGS.walker_weight_envelope_steps,
                delay=FLAGS.walker_weight_envelope_delay)
            tf.summary.scalar('Weights_Visit', visit_weight)
            tf.summary.scalar('Weights_Walker', walker_weight)

            if FLAGS.unsup_samples != 0:
                model.add_semisup_loss(t_sup_emb,
                                       t_unsup_emb,
                                       t_sup_labels,
                                       visit_weight=visit_weight,
                                       walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)
            model.add_accuracy(t_sup_logit, t_sup_labels)
            model.add_test_accuracy()

            # Set up learning rate
            t_learning_rate = tf.maximum(
                tf.train.exponential_decay(FLAGS.learning_rate,
                                           model.step,
                                           FLAGS.decay_steps,
                                           FLAGS.decay_factor,
                                           staircase=True),
                FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            optimizer = FLAGS.optimizer
            if optimizer == 'None':
                optimizer = None
            train_op = model.create_train_op(t_learning_rate,
                                             optimizer=optimizer)

            if FLAGS.num_cpus > 0:
                num_cpus = FLAGS.num_cpus
                config = tf.ConfigProto(intra_op_parallelism_threads=num_cpus,
                                        inter_op_parallelism_threads=num_cpus,
                                        allow_soft_placement=True,
                                        device_count={'CPU': num_cpus})
            else:
                config = tf.ConfigProto()

            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            slim.learning.train(
                train_op,
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
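Both the envelope calls above and the commented-out schedules in the next example rely on an apply_envelope helper defined elsewhere in the training script. A minimal sketch of what such an envelope could look like, assuming the argument names from the call sites and supporting only constant, linear, and logarithmic ramps (the actual helper may differ):

    import numpy as np
    import tensorflow as tf

    def apply_envelope(type, step, final_weight, growing_steps, delay):
        # Ramp a loss weight from 0 up to final_weight over `growing_steps`
        # steps, starting after `delay` steps. `step` may be a scalar tensor
        # such as model.step. Hypothetical sketch, not the script's helper.
        t = tf.cast(step - delay, tf.float32)
        g = float(growing_steps)
        if type is None:
            ramp = tf.constant(1.0)
        elif type in ('lin', 'linear'):
            ramp = t / g
        elif type in ('log', 'logarithmic'):
            ramp = tf.log(1.0 + tf.maximum(t, 0.0)) / np.log(1.0 + g)
        else:
            raise ValueError('Unknown envelope type: {}'.format(type))
        return final_weight * tf.clip_by_value(ramp, 0.0, 1.0)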
Example No. 12
def main(_):

  train_images, train_labels = data.load_training_data(FLAGS.cifar)
  test_images, test_labels = data.load_test_data(FLAGS.cifar)

  print(train_images.shape, train_labels.shape)

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else np.random.randint(0, 1000)
  print('Seed:', seed)
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():

    def augment(image):
      # image_size = 28
      # image = tf.image.resize_image_with_crop_or_pad(
      #     image, image_size + 4, image_size + 4)
      # image = tf.random_crop(image, [image_size, image_size, 3])
      image = tf.image.random_flip_left_right(image)
      # Brightness/saturation/contrast provides small gains .2%~.5% on CIFAR.
      # image = tf.image.random_brightness(image, max_delta=63. / 255.)
      # image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
      return image

    model_func = getattr(semisup.architectures, FLAGS.model)
    model = semisup.SemisupModel(model_func, NUM_LABELS, IMAGE_SHAPE,
                                 emb_size=256,
                                 dropout_keep_prob=FLAGS.dropout_keep_prob,
                                 augmentation_function=augment)

    if FLAGS.random_batches:
      sup_lbls = np.asarray(
          np.hstack([np.ones(len(i)) * ind
                     for ind, i in enumerate(sup_by_label)]), int)
      sup_images = np.vstack(sup_by_label)
      batch_size = FLAGS.sup_per_batch * NUM_LABELS
      t_sup_images, t_sup_labels = semisup.create_input(sup_images, sup_lbls,
                                                        batch_size)
    else:
      t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
          sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    if FLAGS.semisup:
      if FLAGS.equal_cls_unsup:
        allimgs_bylabel = semisup.sample_by_label(
            train_images, train_labels, int(NUM_TRAIN_IMAGES / NUM_LABELS),
            NUM_LABELS, seed)
        t_unsup_images, _ = semisup.create_per_class_inputs(
            allimgs_bylabel, FLAGS.sup_per_batch)
      else:
        t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                                 FLAGS.unsup_batch_size)

      t_unsup_emb = model.image_to_embedding(t_unsup_images)

      # Loss weights are fed per step so they can be scheduled.
      proximity_weight = tf.placeholder(tf.float32, shape=[])
      visit_weight = tf.placeholder(tf.float32, shape=[])
      walker_weight = tf.placeholder(tf.float32, shape=[])

      model.add_semisup_loss(
          t_sup_emb, t_unsup_emb, t_sup_labels,
          walker_weight=walker_weight, visit_weight=visit_weight,
          proximity_weight=proximity_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
      FLAGS.learning_rate,
      model.step,
      FLAGS.decay_steps,
      FLAGS.decay_factor,
      staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()
  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in range(FLAGS.max_steps):
      feed_dict = {}
      if FLAGS.semisup:
        # Constant weights; a ramped schedule is sketched after this example,
        # e.g. 0.1 + apply_envelope('lin', step, FLAGS.walker_weight, 500,
        # FLAGS.decay_steps).
        if FLAGS.proximity_loss:
          # Swap the visit loss for the proximity loss.
          feed_dict = {
              walker_weight: FLAGS.walker_weight,
              visit_weight: 0,
              proximity_weight: FLAGS.visit_weight,
          }
        else:
          feed_dict = {
              walker_weight: FLAGS.walker_weight,
              visit_weight: FLAGS.visit_weight,
              proximity_weight: 0,
          }

      _, summaries, train_loss = sess.run(
          [train_op, summary_op, model.train_loss], feed_dict)

      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images, sess).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        #saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)
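The constant weights fed above can instead be ramped per step, as the inline comment in the loop notes. A hedged sketch of that variant with a plain-Python envelope (np_envelope is a hypothetical NumPy analogue of apply_envelope; feed_dict values must be numbers, not tensors):

    import numpy as np

    def np_envelope(kind, step, final_weight, growing_steps, delay):
        # Hypothetical NumPy analogue of apply_envelope for feed_dict use.
        t = max(float(step - delay), 0.0)
        ramp = {'lin': t / growing_steps,
                'log': np.log1p(t) / np.log1p(growing_steps)}[kind]
        return final_weight * min(ramp, 1.0)

    # Inside the training loop, replacing the constant weights:
    feed_dict = {
        walker_weight: 0.1 + np_envelope('lin', step, FLAGS.walker_weight,
                                         500, FLAGS.decay_steps),
        visit_weight: 0.1 + np_envelope('lin', step, FLAGS.visit_weight,
                                        500, FLAGS.decay_steps),
        proximity_weight: 0,
    }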