Example #1
    def test_loss(self, dist_norm, top_k, loss_type):
        """Checks that the loss is a scalar and its gradients are defined."""
        image_shape = (12, 12, 1)
        num_classes = 10
        batch_size = 3
        images = tf.convert_to_tensor(np.random.rand(*((batch_size, ) +
                                                       image_shape)),
                                      dtype=tf.float32)
        labels = tf.convert_to_tensor(np.random.randint(0,
                                                        high=num_classes,
                                                        size=batch_size),
                                      dtype=tf.int32)
        # Toy model.
        endpoints = {}
        endpoints["input_layer"] = images
        # Convolution layer.
        net = tf.layers.conv2d(images,
                               filters=8,
                               kernel_size=3,
                               strides=(1, 1),
                               padding="same",
                               activation=tf.nn.relu)
        endpoints["conv_layer"] = net
        # Global average pooling layer.
        net = tf.reduce_mean(net, axis=[1, 2])
        # Output layer.
        logits = tf.layers.dense(net, num_classes)
        loss = margin_loss.large_margin(
            logits=logits,
            one_hot_labels=tf.one_hot(labels, num_classes),
            layers_list=[endpoints["input_layer"], endpoints["conv_layer"]],
            gamma=10000,
            alpha_factor=4,
            top_k=top_k,
            dist_norm=dist_norm,
            loss_type=loss_type)
        var_list = tf.global_variables()
        init = tf.global_variables_initializer()

        # Test gradients are not None.
        gs = tf.gradients(loss, var_list)
        for g in gs:
            self.assertIsNotNone(g)

        # Test loss shape.
        with self.test_session() as sess:
            sess.run(init)
            self.assertEqual(sess.run(loss).shape, ())
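
The test above exercises margin_loss.large_margin end to end inside a test class. For reference, a minimal standalone sketch of the same call follows; the import path, dist_norm, and top_k values are assumptions rather than values taken from the example.

# Standalone sketch (assumes TensorFlow 1.x; the import path is hypothetical).
import numpy as np
import tensorflow as tf
from large_margin import margin_loss

batch_size, num_classes = 4, 10
images = tf.placeholder(tf.float32, [batch_size, 12, 12, 1])
labels = tf.placeholder(tf.int32, [batch_size])

# Same toy topology as the test: conv -> global average pool -> dense.
net = tf.layers.conv2d(images, filters=8, kernel_size=3, padding="same",
                       activation=tf.nn.relu)
logits = tf.layers.dense(tf.reduce_mean(net, axis=[1, 2]), num_classes)

loss = margin_loss.large_margin(
    logits=logits,
    one_hot_labels=tf.one_hot(labels, num_classes),
    layers_list=[images, net],  # layers at which the margin is enforced
    gamma=10000,
    alpha_factor=4,
    top_k=1,      # assumed value; the test parameterizes this
    dist_norm=2)  # assumed value; the test parameterizes this

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss, feed_dict={
        images: np.random.rand(batch_size, 12, 12, 1),
        labels: np.random.randint(0, num_classes, size=batch_size)}))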
Example #2
def train():
    """Training function."""
    is_chief = (FLAGS.task == 0)
    g = tf.Graph()
    with g.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks)):
            if FLAGS.experiment_type == "mnist":
                config = mnist_config.ConfigDict()
                dataset = mnist.MNIST(data_dir=FLAGS.data_dir,
                                      subset="train",
                                      batch_size=FLAGS.batch_size,
                                      is_training=True)
                model = mnist_model.MNISTNetwork(config)
                layers_names = [
                    "conv_layer%d" % i
                    for i in range(len(config.filter_sizes_conv_layers))
                ]
            else:
                raise ValueError(
                    "Unsupported experiment_type: %s" % FLAGS.experiment_type)

            images, labels, num_examples, num_classes = (dataset.images,
                                                         dataset.labels,
                                                         dataset.num_examples,
                                                         dataset.num_classes)
            tf.summary.image("images", images)

            # Build model.
            logits, endpoints = model(images, is_training=True)
            layers_list = [images] + [endpoints[name] for name in layers_names]

            # Define losses.
            l2_loss_wt = config.l2_loss_wt
            xent_loss_wt = config.xent_loss_wt
            margin_loss_wt = config.margin_loss_wt
            gamma = config.gamma
            alpha = config.alpha
            top_k = config.top_k
            dist_norm = config.dist_norm
            with tf.name_scope("losses"):
                xent_loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=logits, labels=labels))
                margin = margin_loss.large_margin(
                    logits=logits,
                    one_hot_labels=tf.one_hot(labels, num_classes),
                    layers_list=layers_list,
                    gamma=gamma,
                    alpha_factor=alpha,
                    top_k=top_k,
                    dist_norm=dist_norm,
                    epsilon=1e-6,
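                    # For dist_norm=inf, weight each layer's margin term by
                    # the number of per-example elements in that layer.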
                    layers_weights=[
                        np.prod(layer.get_shape().as_list()[1:])
                        for layer in layers_list
                    ] if np.isinf(dist_norm) else None)

                l2_loss = 0.
                for v in tf.trainable_variables():
                    tf.logging.info(v)
                    l2_loss += tf.nn.l2_loss(v)

                total_loss = 0
                if xent_loss_wt > 0:
                    total_loss += xent_loss_wt * xent_loss
                if margin_loss_wt > 0:
                    total_loss += margin_loss_wt * margin
                if l2_loss_wt > 0:
                    total_loss += l2_loss_wt * l2_loss

                tf.summary.scalar("xent_loss", xent_loss)
                tf.summary.scalar("margin_loss", margin)
                tf.summary.scalar("l2_loss", l2_loss)
                tf.summary.scalar("total_loss", total_loss)

            # Build optimizer.
            init_lr = config.init_lr
            with tf.name_scope("optimizer"):
                global_step = tf.train.get_or_create_global_step()
                if FLAGS.num_replicas > 1:
                    num_batches_per_epoch = num_examples // (
                        FLAGS.batch_size * FLAGS.num_replicas)
                else:
                    num_batches_per_epoch = num_examples // FLAGS.batch_size
                max_iters = num_batches_per_epoch * FLAGS.num_epochs

                lr = tf.train.exponential_decay(init_lr,
                                                global_step,
                                                FLAGS.decay_steps,
                                                FLAGS.decay_rate,
                                                staircase=True,
                                                name="lr_schedule")

                tf.summary.scalar("learning_rate", lr)

                var_list = tf.trainable_variables()
                grads = tf.gradients(total_loss, var_list)
                tf.summary.scalar(
                    "grad_norm",
                    tf.reduce_mean([tf.norm(g) for g in grads]))
                # Clip gradients to stabilize training.
                grads, _ = tf.clip_by_global_norm(grads, 5.0)

                opt = tf.train.RMSPropOptimizer(lr,
                                                momentum=FLAGS.momentum,
                                                epsilon=1e-2)
                if FLAGS.num_replicas > 1:
                    opt = tf.train.SyncReplicasOptimizer(
                        opt,
                        replicas_to_aggregate=FLAGS.num_replicas,
                        total_num_replicas=FLAGS.num_replicas)

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    opt_op = opt.apply_gradients(zip(grads, var_list),
                                                 global_step=global_step)

            # Compute accuracy.
            top1_op = tf.nn.in_top_k(logits, labels, 1)
            accuracy = tf.reduce_mean(tf.cast(top1_op, dtype=tf.float32))
            tf.summary.scalar("top1_accuracy", accuracy)

            # Prepare optimization.
            vars_to_save = tf.global_variables()
            saver = tf.train.Saver(var_list=vars_to_save,
                                   max_to_keep=5,
                                   sharded=True)
            merged_summary = tf.summary.merge_all()

            # Hooks for optimization.
            hooks = [tf.train.StopAtStepHook(last_step=max_iters)]
            if not is_chief:
                hooks.append(
                    tf.train.GlobalStepWaiterHook(FLAGS.task *
                                                  FLAGS.startup_delay_steps))

            init_op = tf.global_variables_initializer()
            scaffold = tf.train.Scaffold(init_op=init_op,
                                         summary_op=merged_summary,
                                         saver=saver)

            # Run optimization.
            epoch = 0
            with tf.train.MonitoredTrainingSession(
                    is_chief=is_chief,
                    checkpoint_dir=FLAGS.checkpoint_dir,
                    hooks=hooks,
                    save_checkpoint_secs=FLAGS.save_checkpoint_secs,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    scaffold=scaffold) as sess:
                while not sess.should_stop():
                    _, acc, i = sess.run((opt_op, accuracy, global_step))
                    epoch = i // num_batches_per_epoch
                    if (i % FLAGS.log_every_steps) == 0:
                        tf.logging.info(
                            "global step %d: epoch %d:\n train accuracy %.3f" %
                            (i, epoch, acc))
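
Note the layers_weights argument in the margin call above: when dist_norm is infinite, each layer's margin term is weighted by the number of per-example elements in that layer. A small sketch of that computation, with hypothetical layer shapes:

# Sketch of the layers_weights computation for the infinity norm; the
# shapes are hypothetical and the first dimension is the batch size.
import numpy as np

dist_norm = np.inf
layer_shapes = [(None, 28, 28, 1), (None, 28, 28, 32)]
layers_weights = ([np.prod(shape[1:]) for shape in layer_shapes]
                  if np.isinf(dist_norm) else None)
print(layers_weights)  # [784, 25088]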
Example #3
def evaluate():
    """Evaluating function."""
    g = tf.Graph()
    ops_dict = {}
    with g.as_default():
        # Data set.
        if FLAGS.experiment_type == "mnist":
            config = mnist_config.ConfigDict()
            dataset = mnist.MNIST(data_dir=FLAGS.data_dir,
                                  subset=FLAGS.subset,
                                  batch_size=FLAGS.batch_size,
                                  is_training=False)
            model = mnist_model.MNISTNetwork(config)
        else:
            raise ValueError(
                "Unsupported experiment_type: %s" % FLAGS.experiment_type)

        images, labels, num_examples, num_classes = (dataset.images,
                                                     dataset.labels,
                                                     dataset.num_examples,
                                                     dataset.num_classes)

        logits, _ = model(images, is_training=False)

        top1_op = tf.nn.in_top_k(logits, labels, 1)

        top1_op = tf.cast(top1_op, dtype=tf.float32)
        ops_dict["top1"] = (None, top1_op)
        accuracy_ph = tf.placeholder(tf.float32, None)
        ops_dict["top1_accuracy"] = (accuracy_ph, None)
        tf.summary.scalar("top1_accuracy", accuracy_ph)

        with tf.name_scope("optimizer"):
            global_step = tf.train.get_or_create_global_step()

        # Define losses.
        l2_loss_wt = config.l2_loss_wt
        xent_loss_wt = config.xent_loss_wt
        margin_loss_wt = config.margin_loss_wt
        gamma = config.gamma
        alpha = config.alpha
        top_k = config.top_k
        dist_norm = config.dist_norm
        with tf.name_scope("losses"):
            xent_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                               labels=labels))
            margin = margin_loss.large_margin(logits=logits,
                                              one_hot_labels=tf.one_hot(
                                                  labels, num_classes),
                                              layers_list=[images],
                                              gamma=gamma,
                                              alpha_factor=alpha,
                                              top_k=top_k,
                                              dist_norm=dist_norm)
            l2_loss = 0.
            for v in tf.trainable_variables():
                tf.logging.info(v)
                l2_loss += tf.nn.l2_loss(v)

            total_loss = 0
            if xent_loss_wt > 0:
                total_loss += xent_loss_wt * xent_loss
            if margin_loss_wt > 0:
                total_loss += margin_loss_wt * margin
            if l2_loss_wt > 0:
                total_loss += l2_loss_wt * l2_loss

            xent_loss_ph = tf.placeholder(tf.float32, None)
            margin_loss_ph = tf.placeholder(tf.float32, None)
            l2_loss_ph = tf.placeholder(tf.float32, None)
            total_loss_ph = tf.placeholder(tf.float32, None)
            tf.summary.scalar("xent_loss", xent_loss_ph)
            tf.summary.scalar("margin_loss", margin_loss_ph)
            tf.summary.scalar("l2_loss", l2_loss_ph)
            tf.summary.scalar("total_loss", total_loss_ph)

            ops_dict["losses/xent_loss"] = (xent_loss_ph, xent_loss)
            ops_dict["losses/margin_loss"] = (margin_loss_ph, margin)
            ops_dict["losses/l2_loss"] = (l2_loss_ph, l2_loss)
            ops_dict["losses/total_loss"] = (total_loss_ph, total_loss)

        # Prepare evaluation session.
        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                               tf.get_default_graph())
        vars_to_save = tf.global_variables()
        saver = tf.train.Saver(var_list=vars_to_save)
        scaffold = tf.train.Scaffold(saver=saver)
        session_creator = tf.train.ChiefSessionCreator(
            scaffold=scaffold, checkpoint_dir=FLAGS.checkpoint_dir)
        while True:
            _eval_once(session_creator, ops_dict, summary_writer,
                       merged_summary, global_step, num_examples)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
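
evaluate() defers the actual evaluation pass to an _eval_once helper that is not shown in this example. Judging from the (placeholder, op) pairs in ops_dict, a plausible implementation runs the concrete ops over the whole subset, averages them, and feeds the averages into the summary placeholders. The sketch below is an inference under those assumptions (and reuses the module's tf, np, and FLAGS), not the source's helper.

def _eval_once(session_creator, ops_dict, summary_writer, merged_summary,
               global_step, num_examples):
    """Hypothetical single evaluation pass over the current checkpoint."""
    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
        num_batches = num_examples // FLAGS.batch_size
        # Fetch every concrete op in one sess.run so all metrics see the
        # same batches, and accumulate per-batch means.
        fetches = {name: op for name, (_, op) in ops_dict.items()
                   if op is not None}
        totals = {name: 0. for name in fetches}
        for _ in range(num_batches):
            for name, value in sess.run(fetches).items():
                totals[name] += np.mean(value)
        # Feed the dataset-wide averages into the summary placeholders.
        feed_dict = {
            ops_dict["top1_accuracy"][0]: totals["top1"] / num_batches}
        for name in ("losses/xent_loss", "losses/margin_loss",
                     "losses/l2_loss", "losses/total_loss"):
            feed_dict[ops_dict[name][0]] = totals[name] / num_batches
        summary, step = sess.run((merged_summary, global_step),
                                 feed_dict=feed_dict)
        summary_writer.add_summary(summary, step)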