Example #1
def input_fn():
  """Input function required by the `tf.estimator.Estimator` API."""
  return cifar_input.get_ds_from_tfrecords(
      data_dir=data_dir,
      split=split,
      data_aug=data_aug,
      batch_size=batch_size,
      epochs=epochs,
      shuffle=shuffle,
      prefetch=prefetch,
      data_format=config.data_format)
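A minimal sketch of how such an `input_fn` plugs into the Estimator API; the stub `model_fn` and the `model_dir` value below are placeholders for illustration, not part of the example above:

def model_fn(features, labels, mode):
  # Stub for illustration only; a real model_fn would build RevNet here.
  logits = tf.layers.dense(tf.layers.flatten(features), 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir="/tmp/revnet")  # assumed checkpoint/summary directory
estimator.train(input_fn=input_fn)  # the Estimator calls input_fn() itself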
Example #2
def input_fn(params):
  """Input function required by the `tf.contrib.tpu.TPUEstimator` API."""
  batch_size = params["batch_size"]
  return cifar_input.get_ds_from_tfrecords(
      data_dir=data_dir,
      split=split,
      data_aug=data_aug,
      batch_size=batch_size,  # per-shard batch size
      epochs=epochs,
      shuffle=shuffle,
      prefetch=batch_size,  # prefetch buffer of one per-shard batch
      data_format=config.data_format)
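The difference from the plain Estimator version is the `params` argument: `tf.contrib.tpu.TPUEstimator` calls `input_fn(params)` and fills in `params["batch_size"]` with the per-shard batch size it derives from the global `train_batch_size`. A sketch of that wiring, where the TPU name, directories, and sizes are all assumed values:

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
run_config = tf.contrib.tpu.RunConfig(
    cluster=resolver,
    model_dir="/tmp/revnet_tpu",
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,      # a model_fn returning tf.contrib.tpu.TPUEstimatorSpec
    use_tpu=True,
    train_batch_size=1024,  # global batch; with 8 shards each input_fn sees 128
    config=run_config)
estimator.train(input_fn=input_fn, max_steps=10000)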
Example #3
def get_datasets(data_dir, config):
  """Return dataset."""
  if data_dir is None:
    raise ValueError("No supplied data directory")
  if not os.path.exists(data_dir):
    raise ValueError("Data directory {} does not exist".format(data_dir))
  if config.dataset not in ["cifar-10", "cifar-100"]:
    raise ValueError("Unknown dataset {}".format(config.dataset))

  print("Training on {} dataset.".format(config.dataset))
  sys.stdout.flush()
  data_dir = os.path.join(data_dir, config.dataset)
  if FLAGS.validate:
    # 40k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)
    # 10k Validation set
    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)
  else:
    # 50k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=data_dir,
        split="train_all",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)
    ds_validation = None

  # Always compute loss and accuracy on the whole training and test sets
  ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
      data_dir=data_dir,
      split="train_all",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  return ds_train, ds_train_one_shot, ds_validation, ds_test
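A short sketch of consuming the four returned datasets in an eager loop; `model`, `optimizer`, `train_one_iter`, and `evaluate` are the helpers these examples use elsewhere and are assumed to be in scope:

ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(
    FLAGS.data_dir, config)
for images, labels in ds_train:  # eager iteration over a tf.data.Dataset
  loss = train_one_iter(model, images, labels, optimizer)
if ds_validation is not None:    # only built when FLAGS.validate is set
  acc_validation = evaluate(model, ds_validation.make_one_shot_iterator())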
Example #4
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")

  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()

  if FLAGS.validate:
    # 40k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)
    # 10k Validation set
    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)
  else:
    # 50k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train_all",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)

  # Always compute loss and accuracy on whole training and test set
  ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train_all",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  model = revnet.RevNet(config=config)
  global_step = tfe.Variable(1, trainable=False)
  learning_rate = tf.train.piecewise_constant(
      global_step, config.lr_decay_steps, config.lr_list)
  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=config.momentum)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path:\"{}\" "
            "with global_step: {}".format(latest_path, global_step.numpy()))
      sys.stdout.flush()

  warmup(model, config)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_train = ds_train_one_shot.make_one_shot_iterator()
      acc_train, loss_train = evaluate(model, it_train)
      it_test = ds_test.make_one_shot_iterator()
      acc_test, loss_test = evaluate(model, it_test)
      if FLAGS.validate:
        it_validation = ds_validation.make_one_shot_iterator()
        acc_validation, loss_validation = evaluate(model, it_validation)
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "validation set accuracy {:.4f}, loss {:4.f}"
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_validation,
                  loss_validation, acc_test, loss_test))
      else:
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_test,
                  loss_test))
      sys.stdout.flush()

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Training loss", loss)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            if FLAGS.validate:
              tf.contrib.summary.scalar("Validation accuracy", acc_validation)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" "
            "with global_step: {}".format(saved_path, global_step.numpy()))
      sys.stdout.flush()
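`warmup`, `train_one_iter`, and `evaluate` are referenced above but not defined in these snippets. A minimal sketch of the two training helpers, assuming the model exposes `compute_gradients` and `compute_loss` methods; the exact signatures in the RevNet code may differ:

def train_one_iter(model, inputs, labels, optimizer, global_step=None):
  """Run one optimization step (schematic)."""
  grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
  optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
  return loss

def evaluate(model, iterator):
  """Average accuracy and loss over one pass of an iterator (schematic)."""
  mean_acc = tfe.metrics.Accuracy()
  mean_loss = tfe.metrics.Mean()
  for x, y in iterator:
    logits = model(x, training=False)  # assumed call signature
    mean_acc(tf.cast(y, tf.int64), tf.argmax(logits, axis=1))
    mean_loss(model.compute_loss(logits=logits, labels=y))
  return mean_acc.result().numpy(), mean_loss.result().numpy()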
Example #5
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")

  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()
  model = revnet.RevNet(config=config)

  ds_train = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train",
      data_aug=True,
      batch_size=config.batch_size,
      epochs=config.epochs,
      shuffle=config.shuffle,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  ds_validation = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="validation",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  global_step = tfe.Variable(1, trainable=False)

  def learning_rate():  # TODO(lxuechen): Remove once cl/201089859 is in place
    return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
                                       config.lr_list)

  optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpoint.restore(latest_path)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_validation = ds_validation.make_one_shot_iterator()
      it_test = ds_test.make_one_shot_iterator()
      acc_validation = evaluate(model, it_validation)
      acc_test = evaluate(model, it_test)
      print("Iter {}, "
            "train loss {}, "
            "validation accuracy {}, "
            "test accuracy {}".format(global_step.numpy(), loss, acc_validation,
                                      acc_test))

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Validation accuracy", acc_validation)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            tf.contrib.summary.scalar("Training loss", loss)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      checkpoint.save(file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
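For completeness, the `FLAGS` these examples rely on would be declared roughly as below; the flag names come from the code above, while the help strings and defaults are assumptions:

from absl import app, flags

flags.DEFINE_string("data_dir", None, "Directory holding the CIFAR tfrecords.")
flags.DEFINE_string("train_dir", None, "Directory for checkpoints and summaries.")
flags.DEFINE_boolean("restore", False, "Restore the latest checkpoint on startup.")
flags.DEFINE_boolean("validate", False, "Hold out a 10k validation split.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
  app.run(main)  # main(_) receives argv, matching the signatures above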