  def test_reuse_model(self, is_training):
    batch_size = 5
    images = _construct_images(batch_size)
    config = mnist_config.ConfigDict()

    model = mnist_model.MNISTNetwork(config)

    # Build once.
    logits1, _ = model(images, is_training)
    num_params = len(tf.global_variables())
    l2_loss1 = tf.losses.get_regularization_loss()
    # Build twice.
    logits2, _ = model(images, is_training)
    l2_loss2 = tf.losses.get_regularization_loss()

    # Ensure variables are reused.
    self.assertLen(tf.global_variables(), num_params)
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      # Ensure operations are the same after reuse.
      err_logits = (np.abs(sess.run(logits1 - logits2))).sum()
      self.assertAlmostEqual(err_logits, 0, 9)
      err_losses = (np.abs(sess.run(l2_loss1 - l2_loss2))).sum()
      self.assertAlmostEqual(err_losses, 0, 9)

  def test_build_model(self, is_training):
    batch_size = 5
    num_classes = 10
    images = _construct_images(batch_size)
    config = mnist_config.ConfigDict()
    config.num_classes = num_classes

    model = mnist_model.MNISTNetwork(config)

    logits, _ = model(images, is_training)

    final_shape = (batch_size, num_classes)
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      self.assertEqual(final_shape, sess.run(logits).shape)
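
# Both tests call a module-level helper `_construct_images` that is not shown
# in this example. A minimal sketch, assuming the network consumes MNIST-shaped
# inputs (28x28x1) and the tests only need a deterministic float tensor:
def _construct_images(batch_size):
  """Returns a fixed batch of random MNIST-shaped images (sketch)."""
  images = tf.constant(
      np.random.RandomState(0).randn(batch_size, 28, 28, 1), dtype=tf.float32)
  return images
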
Example #3
def train():
    """Training function."""
    is_chief = (FLAGS.task == 0)
    g = tf.Graph()
    with g.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks)):
            if FLAGS.experiment_type == "mnist":
                config = mnist_config.ConfigDict()
                dataset = mnist.MNIST(data_dir=FLAGS.data_dir,
                                      subset="train",
                                      batch_size=FLAGS.batch_size,
                                      is_training=True)
                model = mnist_model.MNISTNetwork(config)
                layers_names = [
                    "conv_layer%d" % i
                    for i in range(len(config.filter_sizes_conv_layers))
                ]

            images, labels, num_examples, num_classes = (dataset.images,
                                                         dataset.labels,
                                                         dataset.num_examples,
                                                         dataset.num_classes)
            tf.summary.image("images", images)

            # Build model.
            logits, endpoints = model(images, is_training=True)
            layers_list = [images] + [endpoints[name] for name in layers_names]

            # Define losses.
            l2_loss_wt = config.l2_loss_wt
            xent_loss_wt = config.xent_loss_wt
            margin_loss_wt = config.margin_loss_wt
            gamma = config.gamma
            alpha = config.alpha
            top_k = config.top_k
            dist_norm = config.dist_norm
            with tf.name_scope("losses"):
                xent_loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=logits, labels=labels))
                margin = margin_loss.large_margin(
                    logits=logits,
                    one_hot_labels=tf.one_hot(labels, num_classes),
                    layers_list=layers_list,
                    gamma=gamma,
                    alpha_factor=alpha,
                    top_k=top_k,
                    dist_norm=dist_norm,
                    epsilon=1e-6,
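                    # Weight each layer by its number of units when the distance
                    # norm is infinity; otherwise let large_margin use its default.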
                    layers_weights=[
                        np.prod(layer.get_shape().as_list()[1:])
                        for layer in layers_list
                    ] if np.isinf(dist_norm) else None)

                l2_loss = 0.
                for v in tf.trainable_variables():
                    tf.logging.info(v)
                    l2_loss += tf.nn.l2_loss(v)

                total_loss = 0
                if xent_loss_wt > 0:
                    total_loss += xent_loss_wt * xent_loss
                if margin_loss_wt > 0:
                    total_loss += margin_loss_wt * margin
                if l2_loss_wt > 0:
                    total_loss += l2_loss_wt * l2_loss

                tf.summary.scalar("xent_loss", xent_loss)
                tf.summary.scalar("margin_loss", margin)
                tf.summary.scalar("l2_loss", l2_loss)
                tf.summary.scalar("total_loss", total_loss)

            # Build optimizer.
            init_lr = config.init_lr
            with tf.name_scope("optimizer"):
                global_step = tf.train.get_or_create_global_step()
                if FLAGS.num_replicas > 1:
                    num_batches_per_epoch = num_examples // (
                        FLAGS.batch_size * FLAGS.num_replicas)
                else:
                    num_batches_per_epoch = num_examples // FLAGS.batch_size
                max_iters = num_batches_per_epoch * FLAGS.num_epochs

                lr = tf.train.exponential_decay(init_lr,
                                                global_step,
                                                FLAGS.decay_steps,
                                                FLAGS.decay_rate,
                                                staircase=True,
                                                name="lr_schedule")

                tf.summary.scalar("learning_rate", lr)

                var_list = tf.trainable_variables()
                grad_vars = tf.gradients(total_loss, var_list)
                tf.summary.scalar(
                    "grad_norm",
                    tf.reduce_mean(
                        [tf.norm(grad_var) for grad_var in grad_vars]))
                grad_vars, _ = tf.clip_by_global_norm(grad_vars, 5.0)

                opt = tf.train.RMSPropOptimizer(lr,
                                                momentum=FLAGS.momentum,
                                                epsilon=1e-2)
                if FLAGS.num_replicas > 1:
                    opt = tf.train.SyncReplicasOptimizer(
                        opt,
                        replicas_to_aggregate=FLAGS.num_replicas,
                        total_num_replicas=FLAGS.num_replicas)

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    opt_op = opt.apply_gradients(zip(grad_vars, var_list),
                                                 global_step=global_step)

            # Compute accuracy.
            top1_op = tf.nn.in_top_k(logits, labels, 1)
            accuracy = tf.reduce_mean(tf.cast(top1_op, dtype=tf.float32))
            tf.summary.scalar("top1_accuracy", accuracy)

            # Prepare optimization.
            vars_to_save = tf.global_variables()
            saver = tf.train.Saver(var_list=vars_to_save,
                                   max_to_keep=5,
                                   sharded=True)
            merged_summary = tf.summary.merge_all()

            # Hooks for optimization.
            hooks = [tf.train.StopAtStepHook(last_step=max_iters)]
            if not is_chief:
                hooks.append(
                    tf.train.GlobalStepWaiterHook(FLAGS.task *
                                                  FLAGS.startup_delay_steps))

            init_op = tf.global_variables_initializer()
            scaffold = tf.train.Scaffold(init_op=init_op,
                                         summary_op=merged_summary,
                                         saver=saver)

            # Run optimization.
            epoch = 0
            with tf.train.MonitoredTrainingSession(
                    is_chief=is_chief,
                    checkpoint_dir=FLAGS.checkpoint_dir,
                    hooks=hooks,
                    save_checkpoint_secs=FLAGS.save_checkpoint_secs,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    scaffold=scaffold) as sess:
                while not sess.should_stop():
                    _, acc, i = sess.run((opt_op, accuracy, global_step))
                    epoch = i // num_batches_per_epoch
                    if (i % FLAGS.log_every_steps) == 0:
                        tf.logging.info(
                            "global step %d: epoch %d:\n train accuracy %.3f" %
                            (i, epoch, acc))
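
# train() reads a module-level FLAGS object that this example does not define.
# A minimal sketch of the flag definitions it depends on; the flag names come
# from the code above, but the defaults and help strings are only illustrative:
flags = tf.app.flags
flags.DEFINE_string("experiment_type", "mnist", "Which experiment to run.")
flags.DEFINE_string("data_dir", "/tmp/mnist-data", "Input data directory.")
flags.DEFINE_string("checkpoint_dir", "/tmp/large_margin", "Checkpoint directory.")
flags.DEFINE_integer("batch_size", 128, "Training batch size.")
flags.DEFINE_integer("num_epochs", 100, "Number of training epochs.")
flags.DEFINE_integer("num_replicas", 1, "Number of training replicas.")
flags.DEFINE_integer("ps_tasks", 0, "Number of parameter-server tasks.")
flags.DEFINE_integer("task", 0, "Task id of this replica.")
flags.DEFINE_integer("startup_delay_steps", 15, "Replica startup delay in steps.")
flags.DEFINE_integer("decay_steps", 5000, "Steps between learning-rate decays.")
flags.DEFINE_float("decay_rate", 0.96, "Learning-rate decay factor.")
flags.DEFINE_float("momentum", 0.9, "RMSProp momentum.")
flags.DEFINE_integer("save_checkpoint_secs", 600, "Seconds between checkpoints.")
flags.DEFINE_integer("save_summaries_secs", 120, "Seconds between summary writes.")
flags.DEFINE_integer("log_every_steps", 100, "Steps between log messages.")
FLAGS = flags.FLAGS
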
Example #4
def evaluate():
    """Evaluating function."""
    g = tf.Graph()
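    # Maps a summary name to a (placeholder, op) pair; _eval_once (defined
    # elsewhere) presumably evaluates the ops per batch and feeds the aggregated
    # values into the placeholders when writing summaries.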
    ops_dict = {}
    with g.as_default():
        # Data set.
        if FLAGS.experiment_type == "mnist":
            config = mnist_config.ConfigDict()
            dataset = mnist.MNIST(data_dir=FLAGS.data_dir,
                                  subset=FLAGS.subset,
                                  batch_size=FLAGS.batch_size,
                                  is_training=False)
            model = mnist_model.MNISTNetwork(config)

        images, labels, num_examples, num_classes = (dataset.images,
                                                     dataset.labels,
                                                     dataset.num_examples,
                                                     dataset.num_classes)

        logits, _ = model(images, is_training=False)

        top1_op = tf.nn.in_top_k(logits, labels, 1)

        top1_op = tf.cast(top1_op, dtype=tf.float32)
        ops_dict["top1"] = (None, top1_op)
        accuracy_ph = tf.placeholder(tf.float32, None)
        ops_dict["top1_accuracy"] = (accuracy_ph, None)
        tf.summary.scalar("top1_accuracy", accuracy_ph)

        with tf.name_scope("optimizer"):
            global_step = tf.train.get_or_create_global_step()

        # Define losses.
        l2_loss_wt = config.l2_loss_wt
        xent_loss_wt = config.xent_loss_wt
        margin_loss_wt = config.margin_loss_wt
        gamma = config.gamma
        alpha = config.alpha
        top_k = config.top_k
        dist_norm = config.dist_norm
        with tf.name_scope("losses"):
            xent_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                               labels=labels))
            margin = margin_loss.large_margin(logits=logits,
                                              one_hot_labels=tf.one_hot(
                                                  labels, num_classes),
                                              layers_list=[images],
                                              gamma=gamma,
                                              alpha_factor=alpha,
                                              top_k=top_k,
                                              dist_norm=dist_norm)
            l2_loss = 0.
            for v in tf.trainable_variables():
                tf.logging.info(v)
                l2_loss += tf.nn.l2_loss(v)

            total_loss = 0
            if xent_loss_wt > 0:
                total_loss += xent_loss_wt * xent_loss
            if margin_loss_wt > 0:
                total_loss += margin_loss_wt * margin
            if l2_loss_wt > 0:
                total_loss += l2_loss_wt * l2_loss

            xent_loss_ph = tf.placeholder(tf.float32, None)
            margin_loss_ph = tf.placeholder(tf.float32, None)
            l2_loss_ph = tf.placeholder(tf.float32, None)
            total_loss_ph = tf.placeholder(tf.float32, None)
            tf.summary.scalar("xent_loss", xent_loss_ph)
            tf.summary.scalar("margin_loss", margin_loss_ph)
            tf.summary.scalar("l2_loss", l2_loss_ph)
            tf.summary.scalar("total_loss", total_loss_ph)

            ops_dict["losses/xent_loss"] = (xent_loss_ph, xent_loss)
            ops_dict["losses/margin_loss"] = (margin_loss_ph, margin)
            ops_dict["losses/l2_loss"] = (l2_loss_ph, l2_loss)
            ops_dict["losses/total_loss"] = (total_loss_ph, total_loss)

        # Prepare evaluation session.
        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                               tf.get_default_graph())
        vars_to_save = tf.global_variables()
        saver = tf.train.Saver(var_list=vars_to_save)
        scaffold = tf.train.Scaffold(saver=saver)
        session_creator = tf.train.ChiefSessionCreator(
            scaffold=scaffold, checkpoint_dir=FLAGS.checkpoint_dir)
        while True:
            _eval_once(session_creator, ops_dict, summary_writer,
                       merged_summary, global_step, num_examples)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
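
# evaluate() also reads FLAGS.subset, FLAGS.eval_dir, FLAGS.run_once and
# FLAGS.eval_interval_secs, and delegates the per-checkpoint work to _eval_once,
# which is not part of this example. A minimal sketch of such a helper, assuming
# ops_dict maps names to (placeholder, op) pairs: ops (where present) are
# averaged over the evaluation set and the averages are fed back through the
# matching placeholders before one merged summary is written per checkpoint.
def _eval_once(session_creator, ops_dict, summary_writer, merged_summary,
               global_step, num_examples):
    """Runs one pass over the evaluation data and writes summaries (sketch)."""
    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
        num_batches = num_examples // FLAGS.batch_size
        # Accumulate per-batch results for every entry that has an op to run.
        results = {name: [] for name, (_, op) in ops_dict.items()
                   if op is not None}
        for _ in range(num_batches):
            for name, (_, op) in ops_dict.items():
                if op is not None:
                    results[name].append(np.mean(sess.run(op)))
        # Feed aggregated values into the summary placeholders. The mapping of
        # "top1" batch results onto the "top1_accuracy" placeholder is an
        # assumption about how the original helper pairs the two entries.
        feed_dict = {}
        for name, (placeholder, _) in ops_dict.items():
            if placeholder is not None:
                source = "top1" if name == "top1_accuracy" else name
                feed_dict[placeholder] = float(np.mean(results[source]))
        summary, step = sess.run((merged_summary, global_step),
                                 feed_dict=feed_dict)
        summary_writer.add_summary(summary, step)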