Example #1
 def _benchmark_eager_apply(self,
                            label,
                            device_and_format,
                            defun=False,
                            execution_mode=None,
                            compiled=False):
   config = config_.get_hparams_imagenet_56()
   with tfe.execution_mode(execution_mode):
     device, data_format = device_and_format
     model = revnet.RevNet(config=config)
     if defun:
       model.call = tfe.defun(model.call, compiled=compiled)
     batch_size = 64
     num_burn = 5
     num_iters = 10
     with tf.device(device):
       images, _ = random_batch(batch_size, config)
       for _ in range(num_burn):
         model(images, training=False)
       if execution_mode:
         tfe.async_wait()
       gc.collect()
       start = time.time()
       for _ in range(num_iters):
         model(images, training=False)
       if execution_mode:
         tfe.async_wait()
       self._report(label, start, num_iters, device, batch_size, data_format)
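For context, `tfe.defun` (from `tf.contrib.eager`) traces a Python function into a TensorFlow graph function on its first call and reuses the traced graph afterwards, which is the code path the `defun=True` benchmark above measures. A minimal, self-contained illustration (the `square` function below is purely illustrative and not part of the benchmark):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def square(x):
  # Plain Python function executed eagerly.
  return x * x

# Wrapping with defun traces the function into a graph on the first call;
# subsequent calls with compatible inputs reuse the traced graph.
square_defun = tfe.defun(square)
print(square_defun(tf.constant(3.0)).numpy())  # 9.0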
Example #2
  def test_training_graph(self):
    """Test model training in graph mode."""

    with tf.Graph().as_default():
      x = tf.random_normal(
          shape=(self.config.batch_size,) + self.config.input_shape)
      t = tf.random_uniform(
          shape=(self.config.batch_size,),
          minval=0,
          maxval=self.config.n_classes,
          dtype=tf.int32)
      global_step = tfe.Variable(0., trainable=False)
      model = revnet.RevNet(config=self.config)
      grads_all, vars_all, _ = model.compute_gradients(x, t, training=True)
      optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
      updates = model.get_updates_for(x)
      self.assertEqual(len(updates), 192)
      with tf.control_dependencies(model.get_updates_for(x)):
        train_op = optimizer.apply_gradients(
            zip(grads_all, vars_all), global_step=global_step)

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(1):
          sess.run(train_op)
Example #3
def model_fn(features, labels, mode, params):
  """Function specifying the model that is required by the `tf.estimator` API.

  Args:
    features: Input images
    labels: Labels of images
    mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL`, or `ModeKeys.PREDICT`
    params: A dictionary of extra parameters that might be passed

  Returns:
    An instance of `tf.estimator.EstimatorSpec`
  """

  inputs = features
  if isinstance(inputs, dict):
    inputs = features["image"]

  config = params["config"]
  model = revnet.RevNet(config=config)

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(
        global_step, config.lr_decay_steps, config.lr_list)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=config.momentum)
    logits, saved_hidden = model(inputs, training=True)
    grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
    with tf.control_dependencies(model.get_updates_for(inputs)):
      train_op = optimizer.apply_gradients(
          zip(grads, model.trainable_variables), global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
  else:
    logits, _ = model(inputs, training=False)
    predictions = tf.argmax(logits, axis=1)
    probabilities = tf.nn.softmax(logits)

    if mode == tf.estimator.ModeKeys.EVAL:
      loss = model.compute_loss(labels=labels, logits=logits)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metric_ops={
              "accuracy":
                  tf.metrics.accuracy(labels=labels, predictions=predictions)
          })

    else:  # mode == tf.estimator.ModeKeys.PREDICT
      result = {
          "classes": predictions,
          "probabilities": probabilities,
      }

      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          export_outputs={
              "classify": tf.estimator.export.PredictOutput(result)
          })
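For reference, a `model_fn` like the one above is not called directly; it is handed to `tf.estimator.Estimator`, which invokes it with the appropriate `mode`. The wiring below is a hypothetical sketch: `train_input_fn` and the `model_dir` path are placeholders, and `config` is assumed to be the same RevNet hparams object the `model_fn` reads from `params`.

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir="/tmp/revnet_estimator",  # placeholder directory
    params={"config": config})          # same key the model_fn looks up

# train_input_fn is assumed to return (features, labels) batches, e.g. built
# with tf.data; it is not part of the excerpt above.
estimator.train(input_fn=train_input_fn, max_steps=1000)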
Example #4
  def test_training_graph(self):
    """Test model training in graph mode."""
    with tf.Graph().as_default():
      config = config_.get_hparams_cifar_38()
      config.add_hparam("n_classes", 10)
      config.add_hparam("dataset", "cifar-10")

      x = tf.random_normal(
          shape=(self.config.batch_size,) + self.config.input_shape)
      t = tf.random_uniform(
          shape=(self.config.batch_size,),
          minval=0,
          maxval=self.config.n_classes,
          dtype=tf.int32)
      global_step = tf.Variable(0., trainable=False)
      model = revnet.RevNet(config=config)
      _, saved_hidden = model(x)
      grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t)
      optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
      train_op = optimizer.apply_gradients(
          zip(grads, model.trainable_variables), global_step=global_step)

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(1):
          sess.run(train_op)
Example #5
 def setUp(self):
     super(RevnetTest, self).setUp()
     config = config_.get_hparams_imagenet_56()
     shape = (config.batch_size, ) + config.input_shape
     self.model = revnet.RevNet(config=config)
     self.x = tf.random_normal(shape=shape)
     self.t = tf.random_uniform(shape=[config.batch_size],
                                minval=0,
                                maxval=config.n_classes,
                                dtype=tf.int32)
     self.config = config
Example #6
  def test_train_step_defun(self):
    self.model.call = tfe.defun(self.model.call)
    logits, _ = self.model(self.x, training=True)
    loss = self.model.compute_loss(logits=logits, labels=self.t)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

    for _ in range(3):
      loss_ = self.model.train_step(self.x, self.t, optimizer, report=True)
      self.assertTrue(loss_.numpy() <= loss.numpy())
      loss = loss_

    # Initialize new model, so that other tests are not affected
    self.model = revnet.RevNet(config=self.config)
Example #7
 def setUp(self):
     super(RevNetTest, self).setUp()
     config = config_.get_hparams_cifar_38()
     # Reconstruction could cause numerical error, use double precision for tests
     config.dtype = tf.float64
     config.fused = False  # Fused batch norm does not support tf.float64
     shape = (config.batch_size, ) + config.input_shape
     self.model = revnet.RevNet(config=config)
     self.x = tf.random_normal(shape=shape, dtype=tf.float64)
     self.t = tf.random_uniform(shape=[config.batch_size],
                                minval=0,
                                maxval=config.n_classes,
                                dtype=tf.int64)
     self.config = config
Example #8
 def setUp(self):
   super(RevNetTest, self).setUp()
   config = config_.get_hparams_cifar_38()
   config.add_hparam("n_classes", 10)
   config.add_hparam("dataset", "cifar-10")
   # Reconstruction could cause numerical error, use double precision for tests
   config.dtype = tf.float64
   config.fused = False  # Fused batch norm does not support tf.float64
   # Reduce the batch size for tests because the OSS version runs
   # in constrained GPU environment with 1-2GB of memory.
   config.batch_size = 2
   shape = (config.batch_size,) + config.input_shape
   self.model = revnet.RevNet(config=config)
   self.x = tf.random_normal(shape=shape, dtype=tf.float64)
   self.t = tf.random_uniform(
       shape=[config.batch_size],
       minval=0,
       maxval=config.n_classes,
       dtype=tf.int64)
   self.config = config
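A typical test method built on a `setUp` like the ones above just runs a forward pass and checks shapes. The method below is a hypothetical sketch rather than one of the original tests; it relies only on the fixtures defined in `setUp` and on the fact, visible in the other examples, that calling the model returns a `(logits, saved_hidden)` pair.

  def test_forward_pass_smoke(self):
    # Forward pass in eager mode; training=True exercises the batch-norm updates.
    logits, _ = self.model(self.x, training=True)
    self.assertEqual(logits.shape,
                     (self.config.batch_size, self.config.n_classes))
    # The loss reduces to a scalar.
    loss = self.model.compute_loss(logits=logits, labels=self.t)
    self.assertEqual(loss.shape, ())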
Example #9
    def _benchmark_eager_train(self,
                               label,
                               make_iterator,
                               device_and_format,
                               defun=False,
                               execution_mode=None,
                               compiled=False):
        config = config_.get_hparams_imagenet_56()
        config.add_hparam("n_classes", 1000)
        config.add_hparam("dataset", "ImageNet")
        with tfe.execution_mode(execution_mode):
            device, data_format = device_and_format
            for batch_size in self._train_batch_sizes():
                (images, labels) = random_batch(batch_size, config)
                model = revnet.RevNet(config=config)
                optimizer = tf.train.GradientDescentOptimizer(0.1)
                if defun:
                    model.call = tfe.defun(model.call)

                num_burn = 3
                num_iters = 10
                with tf.device(device):
                    iterator = make_iterator((images, labels))
                    for _ in range(num_burn):
                        (images, labels) = iterator.next()
                        train_one_iter(model, images, labels, optimizer)
                    if execution_mode:
                        tfe.async_wait()
                    self._force_device_sync()
                    gc.collect()

                    start = time.time()
                    for _ in range(num_iters):
                        (images, labels) = iterator.next()
                        train_one_iter(model, images, labels, optimizer)
                    if execution_mode:
                        tfe.async_wait()
                    self._force_device_sync()
                    self._report(label, start, num_iters, device, batch_size,
                                 data_format)
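The benchmarks in Examples #1 and #9 rely on a `random_batch(batch_size, config)` helper that is not included in these excerpts. A minimal sketch of what such a helper might look like, assuming `config` carries `input_shape`, `n_classes`, and `dtype` as in the other examples (illustrative only, not the original implementation):

def random_batch(batch_size, config):
    # Synthesize random images and integer labels matching the model's expected shapes.
    images = tf.random_uniform(
        shape=(batch_size,) + config.input_shape, dtype=config.dtype)
    labels = tf.random_uniform(
        shape=[batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32)
    return images, labels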
Example #10
def main(_):
    """Eager execution workflow with RevNet trained on CIFAR-10."""
    if FLAGS.data_dir is None:
        raise ValueError("No supplied data directory")

    if not os.path.exists(FLAGS.data_dir):
        raise ValueError("Data directory {} does not exist".format(
            FLAGS.data_dir))

    tf.enable_eager_execution()
    config = config_.get_hparams_cifar_38()
    model = revnet.RevNet(config=config)

    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.prefetch)

    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.prefetch)

    ds_test = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="test",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.prefetch)

    global_step = tfe.Variable(1, trainable=False)

    def learning_rate():  # TODO(lxuechen): Remove once cl/201089859 is in place
        return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
                                           config.lr_list)

    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     optimizer_step=global_step)

    if FLAGS.train_dir:
        summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
        if FLAGS.restore:
            latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            checkpoint.restore(latest_path)

    for x, y in ds_train:
        loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

        if global_step % config.log_every == 0:
            it_validation = ds_validation.make_one_shot_iterator()
            it_test = ds_test.make_one_shot_iterator()
            acc_validation = evaluate(model, it_validation)
            acc_test = evaluate(model, it_test)
            print("Iter {}, "
                  "train loss {}, "
                  "validation accuracy {}, "
                  "test accuracy {}".format(global_step.numpy(), loss,
                                            acc_validation, acc_test))

            if FLAGS.train_dir:
                with summary_writer.as_default():
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar("Validation accuracy",
                                                  acc_validation)
                        tf.contrib.summary.scalar("Test accuracy", acc_test)
                        tf.contrib.summary.scalar("Training loss", loss)

        if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
            checkpoint.save(file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
Example #11
def model_fn(features, labels, mode, params):
    """Model function required by the `tf.contrib.tpu.TPUEstimator` API.

    Args:
      features: Input images
      labels: Labels of images
      mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL`, or `ModeKeys.PREDICT`
      params: A dictionary of extra parameters that might be passed

    Returns:
      An instance of `tf.contrib.tpu.TPUEstimatorSpec`
    """
    revnet_config = params["revnet_config"]
    model = revnet.RevNet(config=revnet_config)

    inputs = features
    if isinstance(inputs, dict):
        inputs = features["image"]

    if revnet_config.data_format == "channels_first":
        assert not FLAGS.transpose_input  # channels_first only for GPU
        inputs = tf.transpose(inputs, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        inputs = tf.transpose(inputs, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype)
    inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.piecewise_constant(
            global_step, revnet_config.lr_decay_steps, revnet_config.lr_list)
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               revnet_config.momentum)
        if FLAGS.use_tpu:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        logits, saved_hidden = model(inputs, training=True)
        grads, loss = model.compute_gradients(saved_hidden,
                                              labels,
                                              training=True)
        with tf.control_dependencies(model.get_updates_for(inputs)):
            train_op = optimizer.apply_gradients(zip(
                grads, model.trainable_variables),
                                                 global_step=global_step)
        host_call = None
        if not FLAGS.skip_host_call:
            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            host_call = (_host_call_fn, [gs_t, loss_t, lr_t])

        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                               loss=loss,
                                               train_op=train_op,
                                               host_call=host_call)

    elif mode == tf.estimator.ModeKeys.EVAL:
        logits, _ = model(inputs, training=False)
        loss = model.compute_loss(labels=labels, logits=logits)

        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                               loss=loss,
                                               eval_metrics=(_metric_fn,
                                                             [labels, logits]))

    else:  # Predict or export
        logits, _ = model(inputs, training=False)
        predictions = {
            "classes": tf.argmax(logits, axis=1),
            "probabilities": tf.nn.softmax(logits),
        }

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })
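For reference, a TPU `model_fn` like the one above is consumed by `tf.contrib.tpu.TPUEstimator` rather than the plain `tf.estimator.Estimator`. The wiring below is a hypothetical sketch: the TPU name, directories, batch sizes, and `train_input_fn` are placeholders, and `revnet_config` is assumed to be the hparams object the `model_fn` reads from `params`.

tpu_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu="my-tpu")  # placeholder TPU name
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_resolver,
    model_dir="/tmp/revnet_tpu",  # placeholder directory
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=True,
    train_batch_size=1024,  # placeholder global batch size
    eval_batch_size=1024,
    config=run_config,
    params={"revnet_config": revnet_config})

estimator.train(input_fn=train_input_fn, max_steps=10000)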
Example #12
def model_fn(features, labels, mode, params):
  """Model function required by the `tf.contrib.tpu.TPUEstimator` API.

  Args:
    features: Input images
    labels: Labels of images
    mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL`, or `ModeKeys.PREDICT`
    params: A dictionary of extra parameters that might be passed

  Returns:
    An instance of `tf.contrib.tpu.TPUEstimatorSpec`
  """

  inputs = features
  if isinstance(inputs, dict):
    inputs = features["image"]

  FLAGS = params["FLAGS"]  # pylint:disable=invalid-name,redefined-outer-name
  config = params["config"]
  model = revnet.RevNet(config=config)

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(
        global_step, config.lr_decay_steps, config.lr_list)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=config.momentum)

    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Define gradients
    grads, vars_, logits, loss = model.compute_gradients(
        inputs, labels, training=True)
    train_op = optimizer.apply_gradients(
        zip(grads, vars_), global_step=global_step)

    names = [v.name for v in model.variables]
    tf.logging.warn("{}".format(names))

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)

  if mode == tf.estimator.ModeKeys.EVAL:
    logits, _ = model(inputs, training=False)
    loss = model.compute_loss(labels=labels, logits=logits)

    def metric_fn(labels, logits):
      predictions = tf.argmax(logits, axis=1)
      accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
      return {
          "accuracy": accuracy,
      }

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits, _ = model(inputs, training=False)
    predictions = {
        "classes": tf.argmax(logits, axis=1),
        "probabilities": tf.nn.softmax(logits),
    }

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })
Example #13
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  tf.enable_eager_execution()

  config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
  ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(
      data_dir=FLAGS.data_dir, config=config)
  model = revnet.RevNet(config=config)
  global_step = tf.train.get_or_create_global_step()  # Ensure correct summary
  global_step.assign(1)
  learning_rate = tf.train.piecewise_constant(
      global_step, config.lr_decay_steps, config.lr_list)
  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=config.momentum)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.use_defun:
    model.call = tfe.defun(model.call)
    model.compute_gradients = tfe.defun(model.compute_gradients)
    model.get_moving_stats = tfe.defun(model.get_moving_stats)
    model.restore_moving_stats = tfe.defun(model.restore_moving_stats)
    global apply_gradients  # pylint:disable=global-variable-undefined
    apply_gradients = tfe.defun(apply_gradients)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path:\"{}\" "
            "with global_step: {}".format(latest_path, global_step.numpy()))
      sys.stdout.flush()

  for x, y in ds_train:
    train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_test = ds_test.make_one_shot_iterator()
      acc_test, loss_test = evaluate(model, it_test)

      if FLAGS.validate:
        it_train = ds_train_one_shot.make_one_shot_iterator()
        it_validation = ds_validation.make_one_shot_iterator()
        acc_train, loss_train = evaluate(model, it_train)
        acc_validation, loss_validation = evaluate(model, it_validation)
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "validation set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_validation,
                  loss_validation, acc_test, loss_test))
      else:
        print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format(
            global_step.numpy(), acc_test, loss_test))
      sys.stdout.flush()

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            tf.contrib.summary.scalar("Test loss", loss_test)
            if FLAGS.validate:
              tf.contrib.summary.scalar("Training accuracy", acc_train)
              tf.contrib.summary.scalar("Training loss", loss_train)
              tf.contrib.summary.scalar("Validation accuracy", acc_validation)
              tf.contrib.summary.scalar("Validation loss", loss_validation)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" "
            "with global_step: {}".format(saved_path, global_step.numpy()))
      sys.stdout.flush()
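The `main` functions in Examples #10, #13, and #14 read absl-style flags (`FLAGS.data_dir`, `FLAGS.train_dir`, `FLAGS.restore`, `FLAGS.validate`, ...) and are launched through `tf.app.run`, none of which appears in the excerpts. Below is a hypothetical sketch of that entry-point boilerplate, with the flag subset and defaults chosen only for illustration:

from absl import flags
import tensorflow as tf

flags.DEFINE_string("data_dir", None, "Directory containing the CIFAR-10 TFRecords.")
flags.DEFINE_string("train_dir", None, "Directory for checkpoints and summaries.")
flags.DEFINE_boolean("restore", False, "Restore the latest checkpoint before training.")
flags.DEFINE_boolean("validate", False, "Also evaluate on a held-out validation split.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
  tf.app.run(main=main)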
Example #14
def main(_):
    """Eager execution workflow with RevNet trained on CIFAR-10."""
    if FLAGS.data_dir is None:
        raise ValueError("No supplied data directory")

    if not os.path.exists(FLAGS.data_dir):
        raise ValueError("Data directory {} does not exist".format(
            FLAGS.data_dir))

    tf.enable_eager_execution()
    config = config_.get_hparams_cifar_38()

    if FLAGS.validate:
        # 40k Training set
        ds_train = cifar_input.get_ds_from_tfrecords(
            data_dir=FLAGS.data_dir,
            split="train",
            data_aug=True,
            batch_size=config.batch_size,
            epochs=config.epochs,
            shuffle=config.shuffle,
            data_format=config.data_format,
            dtype=config.dtype,
            prefetch=config.batch_size)
        # 10k Training set
        ds_validation = cifar_input.get_ds_from_tfrecords(
            data_dir=FLAGS.data_dir,
            split="validation",
            data_aug=False,
            batch_size=config.eval_batch_size,
            epochs=1,
            shuffle=False,
            data_format=config.data_format,
            dtype=config.dtype,
            prefetch=config.eval_batch_size)
    else:
        # 50k Training set
        ds_train = cifar_input.get_ds_from_tfrecords(
            data_dir=FLAGS.data_dir,
            split="train_all",
            data_aug=True,
            batch_size=config.batch_size,
            epochs=config.epochs,
            shuffle=config.shuffle,
            data_format=config.data_format,
            dtype=config.dtype,
            prefetch=config.batch_size)

    # Always compute loss and accuracy on whole training and test set
    ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train_all",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)

    ds_test = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="test",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)

    model = revnet.RevNet(config=config)
    global_step = tfe.Variable(1, trainable=False)
    learning_rate = tf.train.piecewise_constant(global_step,
                                                config.lr_decay_steps,
                                                config.lr_list)
    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           momentum=config.momentum)
    checkpointer = tf.train.Checkpoint(optimizer=optimizer,
                                       model=model,
                                       optimizer_step=global_step)

    if FLAGS.train_dir:
        summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
        if FLAGS.restore:
            latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            checkpointer.restore(latest_path)
            print("Restored latest checkpoint at path:\"{}\" "
                  "with global_step: {}".format(latest_path,
                                                global_step.numpy()))
            sys.stdout.flush()

    warmup(model, config)

    for x, y in ds_train:
        loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

        if global_step.numpy() % config.log_every == 0:
            it_train = ds_train_one_shot.make_one_shot_iterator()
            acc_train, loss_train = evaluate(model, it_train)
            it_test = ds_test.make_one_shot_iterator()
            acc_test, loss_test = evaluate(model, it_test)
            if FLAGS.validate:
                it_validation = ds_validation.make_one_shot_iterator()
                acc_validation, loss_validation = evaluate(
                    model, it_validation)
                print("Iter {}, "
                      "training set accuracy {:.4f}, loss {:.4f}; "
                      "validation set accuracy {:.4f}, loss {:4.f}"
                      "test accuracy {:.4f}, loss {:.4f}".format(
                          global_step.numpy(), acc_train, loss_train,
                          acc_validation, loss_validation, acc_test,
                          loss_test))
            else:
                print("Iter {}, "
                      "training set accuracy {:.4f}, loss {:.4f}; "
                      "test accuracy {:.4f}, loss {:.4f}".format(
                          global_step.numpy(), acc_train, loss_train, acc_test,
                          loss_test))
            sys.stdout.flush()

            if FLAGS.train_dir:
                with summary_writer.as_default():
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar("Training loss", loss)
                        tf.contrib.summary.scalar("Test accuracy", acc_test)
                        if FLAGS.validate:
                            tf.contrib.summary.scalar("Validation accuracy",
                                                      acc_validation)

        if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
            saved_path = checkpointer.save(
                file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
            print("Saved checkpoint at path: \"{}\" "
                  "with global_step: {}".format(saved_path,
                                                global_step.numpy()))
            sys.stdout.flush()
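The training loops above also depend on an `evaluate(model, iterator)` helper that is not part of these excerpts. The version below is only a hypothetical sketch consistent with how Examples #13 and #14 consume its result (an accuracy/loss pair averaged over one pass of the iterator); Example #10 uses an earlier variant that returns only accuracy, and neither is the original implementation.

def evaluate(model, iterator):
    # Accumulate accuracy and mean loss over a single pass of the dataset iterator.
    accuracy = tfe.metrics.Accuracy()
    mean_loss = tfe.metrics.Mean()
    for x, y in iterator:
        logits, _ = model(x, training=False)
        mean_loss(model.compute_loss(labels=y, logits=logits))
        accuracy(labels=tf.cast(y, tf.int64),
                 predictions=tf.argmax(logits, axis=1))
    return accuracy.result().numpy(), mean_loss.result().numpy()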