Example #1
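All of the snippets on this page come from the TF 1.x RevNet example and are shown without their imports. A plausible import block, with module paths assumed from the tensorflow/contrib/eager/python/examples/revnet layout (treat the paths as an assumption), would be:

import os
import sys

import tensorflow as tf
import tensorflow.contrib.eager as tfe

# Assumed module paths for the helpers referenced in the snippets below:
from tensorflow.contrib.eager.python.examples.revnet import cifar_input
from tensorflow.contrib.eager.python.examples.revnet import config as config_
from tensorflow.contrib.eager.python.examples.revnet import revnet
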
    def test_training_graph(self):
        """Test model training in graph mode."""
        with tf.Graph().as_default():
            config = config_.get_hparams_cifar_38()
            x = tf.random_normal(shape=(self.config.batch_size, ) +
                                 self.config.input_shape)
            t = tf.random_uniform(shape=(self.config.batch_size, ),
                                  minval=0,
                                  maxval=self.config.n_classes,
                                  dtype=tf.int32)
            global_step = tf.Variable(0., trainable=False)
            model = revnet.RevNet(config=config)
            model(x)
            updates = model.get_updates_for(x)

            x_ = tf.identity(x)
            grads_all, vars_all, _ = model.compute_gradients(x_,
                                                             t,
                                                             training=True)
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
            with tf.control_dependencies(updates):
                train_op = optimizer.apply_gradients(zip(grads_all, vars_all),
                                                     global_step=global_step)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                for _ in range(1):
                    sess.run(train_op)
Example #2
  def test_training_graph(self):
    """Test model training in graph mode."""
    with tf.Graph().as_default():
      config = config_.get_hparams_cifar_38()
      config.add_hparam("n_classes", 10)
      config.add_hparam("dataset", "cifar-10")

      x = tf.random_normal(
          shape=(self.config.batch_size,) + self.config.input_shape)
      t = tf.random_uniform(
          shape=(self.config.batch_size,),
          minval=0,
          maxval=self.config.n_classes,
          dtype=tf.int32)
      global_step = tf.Variable(0., trainable=False)
      model = revnet.RevNet(config=config)
      grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True)
      optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
      train_op = optimizer.apply_gradients(
          zip(grads_all, vars_all), global_step=global_step)

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(1):
          sess.run(train_op)
Example #3
  def test_training_graph(self):
    """Test model training in graph mode."""
    with tf.Graph().as_default():
      config = config_.get_hparams_cifar_38()
      config.add_hparam("n_classes", 10)
      config.add_hparam("dataset", "cifar-10")

      x = tf.random_normal(
          shape=(self.config.batch_size,) + self.config.input_shape)
      t = tf.random_uniform(
          shape=(self.config.batch_size,),
          minval=0,
          maxval=self.config.n_classes,
          dtype=tf.int32)
      global_step = tf.Variable(0., trainable=False)
      model = revnet.RevNet(config=config)
      _, saved_hidden = model(x)
      grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t)
      optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
      train_op = optimizer.apply_gradients(
          zip(grads, model.trainable_variables), global_step=global_step)

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(1):
          sess.run(train_op)
Example #4
  def setUp(self):
      super(RevNetTest, self).setUp()
      config = config_.get_hparams_cifar_38()
      # Reconstruction could cause numerical error, use double precision for tests
      config.dtype = tf.float64
      config.fused = False  # Fused batch norm does not support tf.float64
      shape = (config.batch_size,) + config.input_shape
      self.model = revnet.RevNet(config=config)
      self.x = tf.random_normal(shape=shape, dtype=tf.float64)
      self.t = tf.random_uniform(shape=[config.batch_size],
                                 minval=0,
                                 maxval=config.n_classes,
                                 dtype=tf.int64)
      self.config = config
Example #5
def get_config():
  """Return configuration."""
  print("Config: {}".format(FLAGS.config))
  sys.stdout.flush()
  config = {
      "revnet-38": config_.get_hparams_cifar_38(),
      "revnet-110": config_.get_hparams_cifar_110(),
      "revnet-164": config_.get_hparams_cifar_164(),
  }[FLAGS.config]

  if FLAGS.dataset == "cifar-100":
    config.n_classes = 100

  return config
Example #6
def get_config():
    """Return configuration."""
    print("Config: {}".format(FLAGS.config))
    sys.stdout.flush()
    config = {
        "revnet-38": config_.get_hparams_cifar_38(),
        "revnet-110": config_.get_hparams_cifar_110(),
        "revnet-164": config_.get_hparams_cifar_164(),
    }[FLAGS.config]

    if FLAGS.dataset == "cifar-100":
        config.n_classes = 100

    return config
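
The two get_config() variants above read FLAGS.config and FLAGS.dataset. A minimal absl-flags setup consistent with those names (only the flag names are taken from the code; defaults and help strings are my own) might look like:

from absl import flags

flags.DEFINE_string("config", "revnet-38",
                    "Model config: revnet-38, revnet-110 or revnet-164.")
flags.DEFINE_string("dataset", "cifar-10", "Dataset: cifar-10 or cifar-100.")
FLAGS = flags.FLAGS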
Example #7
  def setUp(self):
    super(RevNetTest, self).setUp()
    config = config_.get_hparams_cifar_38()
    # Reconstruction could cause numerical error, use double precision for tests
    config.dtype = tf.float64
    config.fused = False  # Fused batch norm does not support tf.float64
    shape = (config.batch_size,) + config.input_shape
    self.model = revnet.RevNet(config=config)
    self.x = tf.random_normal(shape=shape, dtype=tf.float64)
    self.t = tf.random_uniform(
        shape=[config.batch_size],
        minval=0,
        maxval=config.n_classes,
        dtype=tf.int64)
    self.config = config
Example #8
def get_config(config_name="revnet-38", dataset="cifar-10"):
  """Return configuration."""
  print("Config: {}".format(config_name))
  sys.stdout.flush()
  config = {
      "revnet-38": config_.get_hparams_cifar_38(),
      "revnet-110": config_.get_hparams_cifar_110(),
      "revnet-164": config_.get_hparams_cifar_164(),
  }[config_name]

  if dataset == "cifar-10":
    config.add_hparam("n_classes", 10)
    config.add_hparam("dataset", "cifar-10")
  else:
    config.add_hparam("n_classes", 100)
    config.add_hparam("dataset", "cifar-100")

  return config
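
For reference, a hypothetical call of the flag-free variant above, reusing revnet from the other snippets:

config = get_config(config_name="revnet-38", dataset="cifar-100")
model = revnet.RevNet(config=config)  # 100-way classifier via config.n_classes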
Example #9
  def setUp(self):
    super(RevNetTest, self).setUp()
    config = config_.get_hparams_cifar_38()
    config.add_hparam("n_classes", 10)
    config.add_hparam("dataset", "cifar-10")
    # Reconstruction could cause numerical error, use double precision for tests
    config.dtype = tf.float64
    config.fused = False  # Fused batch norm does not support tf.float64
    # Reduce the batch size for tests because the OSS version runs
    # in constrained GPU environment with 1-2GB of memory.
    config.batch_size = 2
    shape = (config.batch_size,) + config.input_shape
    self.model = revnet.RevNet(config=config)
    self.x = tf.random_normal(shape=shape, dtype=tf.float64)
    self.t = tf.random_uniform(
        shape=[config.batch_size],
        minval=0,
        maxval=config.n_classes,
        dtype=tf.int64)
    self.config = config
Example #10
def main(_):
    """Eager execution workflow with RevNet trained on CIFAR-10."""
    if FLAGS.data_dir is None:
        raise ValueError("No supplied data directory")

    if not os.path.exists(FLAGS.data_dir):
        raise ValueError("Data directory {} does not exist".format(
            FLAGS.data_dir))

    tf.enable_eager_execution()
    config = config_.get_hparams_cifar_38()
    model = revnet.RevNet(config=config)

    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.prefetch)

    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.prefetch)

    ds_test = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="test",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.prefetch)

    global_step = tfe.Variable(1, trainable=False)

    def learning_rate():  # TODO(lxuechen): Remove once cl/201089859 is in place
        return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
                                           config.lr_list)

    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     optimizer_step=global_step)

    if FLAGS.train_dir:
        summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
        if FLAGS.restore:
            latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            checkpoint.restore(latest_path)

    for x, y in ds_train:
        loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

        if global_step % config.log_every == 0:
            it_validation = ds_validation.make_one_shot_iterator()
            it_test = ds_test.make_one_shot_iterator()
            acc_validation = evaluate(model, it_validation)
            acc_test = evaluate(model, it_test)
            print("Iter {}, "
                  "train loss {}, "
                  "validation accuracy {}, "
                  "test accuracy {}".format(global_step.numpy(), loss,
                                            acc_validation, acc_test))

            if FLAGS.train_dir:
                with summary_writer.as_default():
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar("Validation accuracy",
                                                  acc_validation)
                        tf.contrib.summary.scalar("Test accuracy", acc_test)
                        tf.contrib.summary.scalar("Training loss", loss)

        if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
            checkpoint.save(file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
Example #11
def main(_):
    """Eager execution workflow with RevNet trained on CIFAR-10."""
    if FLAGS.data_dir is None:
        raise ValueError("No supplied data directory")

    if not os.path.exists(FLAGS.data_dir):
        raise ValueError("Data directory {} does not exist".format(
            FLAGS.data_dir))

    tf.enable_eager_execution()
    config = config_.get_hparams_cifar_38()

    if FLAGS.validate:
        # 40k Training set
        ds_train = cifar_input.get_ds_from_tfrecords(
            data_dir=FLAGS.data_dir,
            split="train",
            data_aug=True,
            batch_size=config.batch_size,
            epochs=config.epochs,
            shuffle=config.shuffle,
            data_format=config.data_format,
            dtype=config.dtype,
            prefetch=config.batch_size)
        # 10k Training set
        ds_validation = cifar_input.get_ds_from_tfrecords(
            data_dir=FLAGS.data_dir,
            split="validation",
            data_aug=False,
            batch_size=config.eval_batch_size,
            epochs=1,
            shuffle=False,
            data_format=config.data_format,
            dtype=config.dtype,
            prefetch=config.eval_batch_size)
    else:
        # 50k Training set
        ds_train = cifar_input.get_ds_from_tfrecords(
            data_dir=FLAGS.data_dir,
            split="train_all",
            data_aug=True,
            batch_size=config.batch_size,
            epochs=config.epochs,
            shuffle=config.shuffle,
            data_format=config.data_format,
            dtype=config.dtype,
            prefetch=config.batch_size)

    # Always compute loss and accuracy on whole training and test set
    ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train_all",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)

    ds_test = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="test",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)

    model = revnet.RevNet(config=config)
    global_step = tfe.Variable(1, trainable=False)
    learning_rate = tf.train.piecewise_constant(global_step,
                                                config.lr_decay_steps,
                                                config.lr_list)
    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           momentum=config.momentum)
    checkpointer = tf.train.Checkpoint(optimizer=optimizer,
                                       model=model,
                                       optimizer_step=global_step)

    if FLAGS.train_dir:
        summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
        if FLAGS.restore:
            latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
            checkpointer.restore(latest_path)
            print("Restored latest checkpoint at path:\"{}\" "
                  "with global_step: {}".format(latest_path,
                                                global_step.numpy()))
            sys.stdout.flush()

    warmup(model, config)

    for x, y in ds_train:
        loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

        if global_step.numpy() % config.log_every == 0:
            it_train = ds_train_one_shot.make_one_shot_iterator()
            acc_train, loss_train = evaluate(model, it_train)
            it_test = ds_test.make_one_shot_iterator()
            acc_test, loss_test = evaluate(model, it_test)
            if FLAGS.validate:
                it_validation = ds_validation.make_one_shot_iterator()
                acc_validation, loss_validation = evaluate(
                    model, it_validation)
                print("Iter {}, "
                      "training set accuracy {:.4f}, loss {:.4f}; "
                      "validation set accuracy {:.4f}, loss {:4.f}"
                      "test accuracy {:.4f}, loss {:.4f}".format(
                          global_step.numpy(), acc_train, loss_train,
                          acc_validation, loss_validation, acc_test,
                          loss_test))
            else:
                print("Iter {}, "
                      "training set accuracy {:.4f}, loss {:.4f}; "
                      "test accuracy {:.4f}, loss {:.4f}".format(
                          global_step.numpy(), acc_train, loss_train, acc_test,
                          loss_test))
            sys.stdout.flush()

            if FLAGS.train_dir:
                with summary_writer.as_default():
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar("Training loss", loss)
                        tf.contrib.summary.scalar("Test accuracy", acc_test)
                        if FLAGS.validate:
                            tf.contrib.summary.scalar("Validation accuracy",
                                                      acc_validation)

        if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
            saved_path = checkpointer.save(
                file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
            print("Saved checkpoint at path: \"{}\" "
                  "with global_step: {}".format(saved_path,
                                                global_step.numpy()))
            sys.stdout.flush()
Example #12
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")

  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()

  if FLAGS.validate:
    # 40k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)
    # 10k Training set
    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)
  else:
    # 50k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train_all",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)

  # Always compute loss and accuracy on whole training and test set
  ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train_all",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  model = revnet.RevNet(config=config)
  global_step = tfe.Variable(1, trainable=False)
  learning_rate = tf.train.piecewise_constant(
      global_step, config.lr_decay_steps, config.lr_list)
  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=config.momentum)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path:\"{}\" "
            "with global_step: {}".format(latest_path, global_step.numpy()))
      sys.stdout.flush()

  warmup(model, config)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_train = ds_train_one_shot.make_one_shot_iterator()
      acc_train, loss_train = evaluate(model, it_train)
      it_test = ds_test.make_one_shot_iterator()
      acc_test, loss_test = evaluate(model, it_test)
      if FLAGS.validate:
        it_validation = ds_validation.make_one_shot_iterator()
        acc_validation, loss_validation = evaluate(model, it_validation)
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "validation set accuracy {:.4f}, loss {:4.f}"
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_validation,
                  loss_validation, acc_test, loss_test))
      else:
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_test,
                  loss_test))
      sys.stdout.flush()

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Training loss", loss)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            if FLAGS.validate:
              tf.contrib.summary.scalar("Validation accuracy", acc_validation)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" "
            "with global_step: {}".format(saved_path, global_step.numpy()))
      sys.stdout.flush()
Example #13
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")

  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()
  model = revnet.RevNet(config=config)

  ds_train = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train",
      data_aug=True,
      batch_size=config.batch_size,
      epochs=config.epochs,
      shuffle=config.shuffle,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  ds_validation = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="validation",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  global_step = tfe.Variable(1, trainable=False)

  def learning_rate():  # TODO(lxuechen): Remove once cl/201089859 is in place
    return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
                                       config.lr_list)

  optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpoint.restore(latest_path)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step % config.log_every == 0:
      it_validation = ds_validation.make_one_shot_iterator()
      it_test = ds_test.make_one_shot_iterator()
      acc_validation = evaluate(model, it_validation)
      acc_test = evaluate(model, it_test)
      print("Iter {}, "
            "train loss {}, "
            "validation accuracy {}, "
            "test accuracy {}".format(global_step.numpy(), loss, acc_validation,
                                      acc_test))

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Validation accuracy", acc_validation)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            tf.contrib.summary.scalar("Training loss", loss)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      checkpoint.save(file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
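
The main() variants above additionally rely on FLAGS.data_dir, FLAGS.train_dir, FLAGS.restore and FLAGS.validate, plus a standard TF 1.x entry point. A sketch with assumed defaults and help strings (only the flag names come from the code):

from absl import flags

flags.DEFINE_string("data_dir", None, "Directory with the CIFAR TFRecords.")
flags.DEFINE_string("train_dir", None,
                    "Directory for checkpoints and summaries.")
flags.DEFINE_boolean("restore", False, "Restore from the latest checkpoint.")
flags.DEFINE_boolean("validate", False,
                     "Train on 40k images and hold out a validation split.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
  tf.app.run(main)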