Example #1
def make_eval_ops(train_vars, ema):
  """Builds ops that evaluate the model with and without Polyak averaging."""

  images, labels, _ = mnist.load_mnist_as_tensors(
      flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype))

  eval_model = Model()
  # Dummy call to build the model; its variables won't exist otherwise.
  eval_model(images)
  eval_vars = eval_model.variables

  update_eval_model = group_assign(eval_vars, train_vars)

  with tf.control_dependencies([update_eval_model]):
    logits = eval_model(images)
    eval_loss, eval_error = compute_loss(
        logits=logits, labels=labels, return_error=True)

    with tf.control_dependencies([eval_loss, eval_error]):
      update_eval_model_avg = group_assign(
          eval_vars, (ema.average(t) for t in train_vars))

      with tf.control_dependencies([update_eval_model_avg]):
        logits = eval_model(images)
        eval_loss_avg, eval_error_avg = compute_loss(
            logits=logits, labels=labels, return_error=True)

  return eval_loss, eval_error, eval_loss_avg, eval_error_avg
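
The pattern above reuses a single set of evaluation variables and serializes the two evaluations with control dependencies: copy in the raw training weights and evaluate, then copy in the exponential-moving (Polyak) averages and evaluate again. Below is a minimal, self-contained sketch of just that ordering in TF 1.x graph mode; the scalar variables and the tf.group of assigns are hypothetical stand-ins for the example's Model and group_assign helpers.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Hypothetical stand-ins for the example's training and evaluation variables.
train_w = tf.get_variable('train_w', initializer=1.0)
eval_w = tf.get_variable('eval_w', initializer=0.0)

ema = tf.train.ExponentialMovingAverage(0.99, zero_debias=True)
update_ema = ema.apply([train_w])    # maintains the Polyak average of train_w
train_step = tf.assign_add(train_w, 1.0)

# 1) Copy the raw training weights into the eval copy, then read it.
copy_raw = tf.group(tf.assign(eval_w, train_w))
with tf.control_dependencies([copy_raw]):
  raw_value = tf.identity(eval_w)
  # 2) Only after that read, overwrite the eval copy with the EMA average
  #    and read it again.
  with tf.control_dependencies([raw_value]):
    copy_avg = tf.group(tf.assign(eval_w, ema.average(train_w)))
    with tf.control_dependencies([copy_avg]):
      avg_value = tf.identity(eval_w)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(5):
    sess.run(train_step)
    sess.run(update_ema)
  print(sess.run([raw_value, avg_value]))  # current weight vs. its running average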
Example #2
def load_mnist():
    """Creates MNIST dataset and wraps it inside cached data reader.

  Returns:
    cached_reader: `data_reader.CachedReader` instance which wraps MNIST
      dataset.
    num_examples: int. The number of training examples.
  """
    # Wrap the dataset in a cached reader, which supports variable-sized
    # training batches and caches the most recently read training batch.

    if not FLAGS.use_alt_data_reader:
        # Version 1 using data_reader.py (slow!)
        dataset, num_examples = mnist.load_mnist_as_dataset(
            flatten_images=False)
        if FLAGS.use_batch_size_schedule:
            max_batch_size = num_examples
        else:
            max_batch_size = FLAGS.batch_size

        # Shuffling before repeating is the correct order unless you want
        # repeated examples to appear in the same batch.
        dataset = (dataset.shuffle(num_examples).repeat().batch(
            max_batch_size).prefetch(5))
        dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

        # This version of CachedDataReader requires the dataset to be shuffled
        return data_reader.CachedDataReader(dataset,
                                            max_batch_size), num_examples

    else:
        # Version 2 using data_reader_alt.py (faster)
        images, labels, num_examples = mnist.load_mnist_as_tensors(
            flatten_images=False)
        dataset = (images, labels)

        # This version of CachedDataReader requires the dataset to NOT be shuffled
        return data_reader_alt.CachedDataReader(dataset,
                                                num_examples), num_examples
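
The slower branch builds a standard shuffle → repeat → batch → prefetch pipeline and drains it through a one-shot iterator. The sketch below isolates that pipeline on hypothetical toy data in TF 1.x graph mode; the CachedDataReader wrappers are omitted.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

num_examples, max_batch_size = 10, 4
toy_data = np.arange(num_examples)  # hypothetical stand-in for the MNIST tensors

dataset = tf.data.Dataset.from_tensor_slices(toy_data)
# Shuffle over the full dataset, then repeat, so each epoch is a complete
# permutation and no example is repeated within an epoch.
dataset = (dataset.shuffle(num_examples).repeat().batch(max_batch_size)
           .prefetch(5))
next_batch = tf.data.make_one_shot_iterator(dataset).get_next()

with tf.Session() as sess:
    for _ in range(3):
        print(sess.run(next_batch))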
Example #3
def construct_train_quants():
    """Builds the training op, optimizer, batch-size schedule, and eval ops."""

    with tf.device(FLAGS.device):
        # Load dataset.
        cached_reader, num_examples = load_mnist()
        batch_size_schedule = _get_batch_size_schedule(num_examples)
        batch_size = tf.placeholder(shape=(),
                                    dtype=tf.int32,
                                    name='batch_size')

        minibatch = cached_reader(batch_size)
        training_model = Model()
        layer_collection = kfac.LayerCollection()

        if FLAGS.use_sua_approx:
            layer_collection.set_default_conv2d_approximation('kron_sua')

        ema = tf.train.ExponentialMovingAverage(FLAGS.polyak_decay,
                                                zero_debias=True)

        def loss_fn(minibatch, layer_collection=None, return_error=False):
            features, labels = minibatch
            logits = training_model(features)
            return compute_loss(logits=logits,
                                labels=labels,
                                layer_collection=layer_collection,
                                return_error=return_error)

        (batch_loss, batch_error) = loss_fn(minibatch,
                                            layer_collection=layer_collection,
                                            return_error=True)

        train_vars = training_model.variables

        # Make training op:
        train_op, opt = make_train_op(
            minibatch,
            batch_size,
            batch_loss,
            layer_collection,
            loss_fn=loss_fn,
            prev_train_batch=cached_reader.cached_batch)

        with tf.control_dependencies([train_op]):
            train_op = ema.apply(train_vars)

        # Make eval ops:
        images, labels, num_examples = mnist.load_mnist_as_tensors(
            flatten_images=True)

        eval_model = Model()
        # Dummy call to build the model; its variables won't exist otherwise.
        eval_model(images)
        eval_vars = eval_model.variables

        update_eval_model = group_assign(eval_vars, train_vars)

        with tf.control_dependencies([update_eval_model]):
            logits = eval_model(images)
            eval_loss, eval_error = compute_loss(logits=logits,
                                                 labels=labels,
                                                 return_error=True)

            with tf.control_dependencies([eval_loss, eval_error]):
                update_eval_model_avg = group_assign(eval_vars,
                                                     (ema.average(t)
                                                      for t in train_vars))

                with tf.control_dependencies([update_eval_model_avg]):
                    logits = eval_model(images)
                    eval_loss_avg, eval_error_avg = compute_loss(
                        logits=logits, labels=labels, return_error=True)

    return (train_op, opt, batch_loss, batch_error, batch_size_schedule,
            batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)
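
A key detail here is that batch_size is a scalar placeholder, so the batch-size schedule can vary per step without rebuilding the graph. The sketch below illustrates that pattern on hypothetical data in TF 1.x graph mode, with a reduce_mean standing in for the real loss and training op.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

all_images = tf.constant(np.random.rand(100, 4).astype(np.float32))
batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')

minibatch = all_images[:batch_size]      # take the first `batch_size` rows
batch_mean = tf.reduce_mean(minibatch)   # stand-in for the real loss / train op

batch_size_schedule = [8, 16, 32]        # hypothetical schedule
with tf.Session() as sess:
    for step, bs in enumerate(batch_size_schedule):
        print(step, bs, sess.run(batch_mean, feed_dict={batch_size: bs}))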