def make_eval_ops(train_vars, ema):
  """Makes ops that evaluate the model, with and without Polyak averaging."""
  images, labels, _ = mnist.load_mnist_as_tensors(
      flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype))

  eval_model = Model()
  eval_model(images)  # This dummy call is needed so that the model's
                      # variables exist before we read them below.
  eval_vars = eval_model.variables

  update_eval_model = group_assign(eval_vars, train_vars)

  with tf.control_dependencies([update_eval_model]):
    logits = eval_model(images)
    eval_loss, eval_error = compute_loss(
        logits=logits, labels=labels, return_error=True)

    with tf.control_dependencies([eval_loss, eval_error]):
      update_eval_model_avg = group_assign(
          eval_vars, (ema.average(t) for t in train_vars))

      with tf.control_dependencies([update_eval_model_avg]):
        logits = eval_model(images)
        eval_loss_avg, eval_error_avg = compute_loss(
            logits=logits, labels=labels, return_error=True)

  return eval_loss, eval_error, eval_loss_avg, eval_error_avg
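
# The following is a minimal, self-contained sketch (not part of the original
# example) of the Polyak-averaging mechanism that make_eval_ops relies on:
# ema.apply() maintains a moving average of each variable, and ema.average()
# reads the averaged value back out. All names inside the sketch are
# hypothetical.
def _ema_sketch():
  """Sketch: track a variable's Polyak (moving) average across updates."""
  w = tf.Variable(0.0)
  ema = tf.train.ExponentialMovingAverage(0.9, zero_debias=True)
  maintain_avg = ema.apply([w])  # Creates/updates the shadow average of w.
  step = tf.assign_add(w, 1.0)   # Stand-in for a real training update.

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
      sess.run(step)          # Update the "weights" first...
      sess.run(maintain_avg)  # ...then fold them into the moving average.
    # The averaged value lags behind the raw one, smoothing out step-to-step
    # noise; this is what the eval ops above read via ema.average(t).
    return sess.run([w, ema.average(w)])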
def load_mnist():
  """Creates the MNIST dataset and wraps it inside a cached data reader.

  Returns:
    cached_reader: `data_reader.CachedDataReader` instance which wraps the
      MNIST dataset.
    num_examples: int. The number of training examples.
  """
  # Wrap the dataset in a cached data reader, which supports variable-sized
  # training batches and caches the most recently read batch.
  if not FLAGS.use_alt_data_reader:
    # Version 1 using data_reader.py (slow!)
    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=False)
    if FLAGS.use_batch_size_schedule:
      max_batch_size = num_examples
    else:
      max_batch_size = FLAGS.batch_size

    # Shuffling before repeating is correct unless you want repeated examples
    # in the same batch.
    dataset = (dataset.shuffle(num_examples).repeat()
               .batch(max_batch_size).prefetch(5))
    dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    # This version of CachedDataReader requires the dataset to be shuffled.
    return data_reader.CachedDataReader(dataset, max_batch_size), num_examples
  else:
    # Version 2 using data_reader_alt.py (faster)
    images, labels, num_examples = mnist.load_mnist_as_tensors(
        flatten_images=False)
    dataset = (images, labels)

    # This version of CachedDataReader requires the dataset to NOT be shuffled.
    return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples
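
# A hedged usage sketch for load_mnist(), based only on how
# construct_train_quants() below consumes the reader: calling the reader with
# a scalar int32 batch-size tensor yields a (features, labels) minibatch, and
# cached_reader.cached_batch exposes the previously read batch. The concrete
# feed values and the initializer call are illustrative assumptions.
def _cached_reader_sketch():
  """Sketch: draw minibatches of different sizes from the cached reader."""
  cached_reader, _ = load_mnist()
  batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='sketch_bs')
  features, labels = cached_reader(batch_size)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The same tensors yield differently sized batches depending on the feed.
    small = sess.run([features, labels], feed_dict={batch_size: 32})
    large = sess.run([features, labels], feed_dict={batch_size: 128})
    return small, large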
def construct_train_quants():
  """Constructs the training op, eval ops, and batch size quantities."""
  with tf.device(FLAGS.device):
    # Load dataset.
    cached_reader, num_examples = load_mnist()
    batch_size_schedule = _get_batch_size_schedule(num_examples)
    batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')

    minibatch = cached_reader(batch_size)
    training_model = Model()
    layer_collection = kfac.LayerCollection()

    if FLAGS.use_sua_approx:
      layer_collection.set_default_conv2d_approximation('kron_sua')

    ema = tf.train.ExponentialMovingAverage(FLAGS.polyak_decay,
                                            zero_debias=True)

    def loss_fn(minibatch, layer_collection=None, return_error=False):
      features, labels = minibatch
      logits = training_model(features)
      return compute_loss(
          logits=logits,
          labels=labels,
          layer_collection=layer_collection,
          return_error=return_error)

    (batch_loss, batch_error) = loss_fn(
        minibatch, layer_collection=layer_collection, return_error=True)

    train_vars = training_model.variables

    # Make training op:
    train_op, opt = make_train_op(
        minibatch, batch_size, batch_loss, layer_collection,
        loss_fn=loss_fn, prev_train_batch=cached_reader.cached_batch)

    with tf.control_dependencies([train_op]):
      train_op = ema.apply(train_vars)

    # Make eval ops, reusing make_eval_ops() above instead of duplicating
    # its body here:
    eval_loss, eval_error, eval_loss_avg, eval_error_avg = make_eval_ops(
        train_vars, ema)

  return (train_op, opt, batch_loss, batch_error, batch_size_schedule,
          batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)
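
# A hedged sketch (not part of the original example) of a training loop that
# consumes what construct_train_quants() returns. The step count, eval
# interval, and the way the batch-size schedule is indexed are illustrative
# assumptions; the feeds and fetches follow directly from the definitions
# above.
def _train_loop_sketch(num_steps=1000, eval_every=100):
  """Sketch: train with a scheduled batch size and evaluate periodically."""
  (train_op, _, batch_loss, batch_error, batch_size_schedule, batch_size,
   eval_loss, eval_error, eval_loss_avg, eval_error_avg) = (
       construct_train_quants())

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):
      # Assumes the schedule is an indexable sequence; the real schedule may
      # be keyed differently (e.g. on epochs).
      size = batch_size_schedule[min(step, len(batch_size_schedule) - 1)]
      _, loss, error = sess.run([train_op, batch_loss, batch_error],
                                feed_dict={batch_size: size})
      if step % eval_every == 0:
        # The control dependencies inside make_eval_ops() order these fetches
        # so the averaged results are computed after the unaveraged ones.
        evals = sess.run(
            [eval_loss, eval_error, eval_loss_avg, eval_error_avg])
        print('step %d: loss=%g error=%g eval=%s' % (step, loss, error, evals))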