def _benchmark_eager_apply(self, label, device_and_format, defun=False,
                           execution_mode=None, compiled=False):
  """Times inference-only passes of a RevNet and reports throughput."""
  config = config_.get_hparams_imagenet_56()
  with tfe.execution_mode(execution_mode):
    device, data_format = device_and_format
    model = revnet.RevNet(config=config)
    if defun:
      model.call = tfe.defun(model.call, compiled=compiled)

    batch_size = 64
    warmup_iters = 5
    timed_iters = 10
    with tf.device(device):
      images, _ = random_batch(batch_size, config)

      # Warm-up passes are excluded from the timed measurement.
      for _ in range(warmup_iters):
        model(images, training=False)
      if execution_mode:
        tfe.async_wait()
      gc.collect()

      start = time.time()
      for _ in range(timed_iters):
        model(images, training=False)
      if execution_mode:
        tfe.async_wait()
      self._report(label, start, timed_iters, device, batch_size, data_format)
def test_training_graph(self):
  """Test model training in graph mode."""
  with tf.Graph().as_default():
    # Random images/labels stand in for a real input pipeline.
    images = tf.random_normal(
        shape=(self.config.batch_size,) + self.config.input_shape)
    labels = tf.random_uniform(
        shape=(self.config.batch_size,),
        minval=0,
        maxval=self.config.n_classes,
        dtype=tf.int32)
    global_step = tfe.Variable(0., trainable=False)
    model = revnet.RevNet(config=self.config)
    grads_all, vars_all, _ = model.compute_gradients(
        images, labels, training=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

    # The model is expected to register exactly 192 batch-norm update ops.
    updates = model.get_updates_for(images)
    self.assertEqual(len(updates), 192)

    with tf.control_dependencies(model.get_updates_for(images)):
      train_op = optimizer.apply_gradients(
          zip(grads_all, vars_all), global_step=global_step)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(1):
        sess.run(train_op)
def model_fn(features, labels, mode, params):
  """Function specifying the model that is required by the `tf.estimator` API.

  Args:
    features: Input images
    labels: Labels of images
    mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT'
    params: A dictionary of extra parameter that might be passed

  Returns:
    An instance of `tf.estimator.EstimatorSpec`
  """
  inputs = features
  if isinstance(inputs, dict):
    inputs = features["image"]

  config = params["config"]
  model = revnet.RevNet(config=config)

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(
        global_step, config.lr_decay_steps, config.lr_list)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=config.momentum)
    logits, saved_hidden = model(inputs, training=True)
    grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
    # Batch-norm moving statistics must be updated alongside the weights.
    with tf.control_dependencies(model.get_updates_for(inputs)):
      train_op = optimizer.apply_gradients(
          zip(grads, model.trainable_variables), global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
  else:
    logits, _ = model(inputs, training=False)
    predictions = tf.argmax(logits, axis=1)
    probabilities = tf.nn.softmax(logits)

    if mode == tf.estimator.ModeKeys.EVAL:
      loss = model.compute_loss(labels=labels, logits=logits)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metric_ops={
              "accuracy":
                  tf.metrics.accuracy(labels=labels, predictions=predictions)
          })
    else:  # mode == tf.estimator.ModeKeys.PREDICT
      result = {
          "classes": predictions,
          "probabilities": probabilities,
      }
      # Fix: surface the full result dict (classes AND probabilities) from
      # `Estimator.predict`, matching `export_outputs` below and the sibling
      # TPU model_fns; previously only the bare argmax tensor was returned.
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=result,
          export_outputs={
              "classify": tf.estimator.export.PredictOutput(result)
          })
def test_training_graph(self):
  """Test model training in graph mode."""
  with tf.Graph().as_default():
    # Build a fresh CIFAR config for the graph-mode model; inputs are sized
    # from self.config so the batch dimensions match the eager-mode setup.
    config = config_.get_hparams_cifar_38()
    config.add_hparam("n_classes", 10)
    config.add_hparam("dataset", "cifar-10")

    x = tf.random_normal(
        shape=(self.config.batch_size,) + self.config.input_shape)
    t = tf.random_uniform(
        shape=(self.config.batch_size,),
        minval=0,
        maxval=self.config.n_classes,
        dtype=tf.int32)
    global_step = tf.Variable(0., trainable=False)

    model = revnet.RevNet(config=config)
    _, saved_hidden = model(x)
    grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.apply_gradients(
        zip(grads, model.trainable_variables), global_step=global_step)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(1):
        sess.run(train_op)
def setUp(self):
  super(RevnetTest, self).setUp()
  hparams = config_.get_hparams_imagenet_56()
  batch_shape = (hparams.batch_size,) + hparams.input_shape
  self.model = revnet.RevNet(config=hparams)
  # Random image batch and integer labels shared by the test methods.
  self.x = tf.random_normal(shape=batch_shape)
  self.t = tf.random_uniform(
      shape=[hparams.batch_size],
      minval=0,
      maxval=hparams.n_classes,
      dtype=tf.int32)
  self.config = hparams
def test_train_step_defun(self):
  self.model.call = tfe.defun(self.model.call)
  logits, _ = self.model(self.x, training=True)
  prev_loss = self.model.compute_loss(logits=logits, labels=self.t)
  optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

  # Each training step on the same batch should not increase the loss.
  for _ in range(3):
    curr_loss = self.model.train_step(self.x, self.t, optimizer, report=True)
    self.assertTrue(curr_loss.numpy() <= prev_loss.numpy())
    prev_loss = curr_loss

  # Initialize new model, so that other tests are not affected
  self.model = revnet.RevNet(config=self.config)
def setUp(self):
  super(RevNetTest, self).setUp()
  hparams = config_.get_hparams_cifar_38()
  # Reconstruction could cause numerical error, use double precision for tests
  hparams.dtype = tf.float64
  hparams.fused = False  # Fused batch norm does not support tf.float64
  batch_shape = (hparams.batch_size,) + hparams.input_shape
  self.model = revnet.RevNet(config=hparams)
  self.x = tf.random_normal(shape=batch_shape, dtype=tf.float64)
  self.t = tf.random_uniform(
      shape=[hparams.batch_size],
      minval=0,
      maxval=hparams.n_classes,
      dtype=tf.int64)
  self.config = hparams
def setUp(self):
  super(RevNetTest, self).setUp()
  hparams = config_.get_hparams_cifar_38()
  hparams.add_hparam("n_classes", 10)
  hparams.add_hparam("dataset", "cifar-10")
  # Reconstruction could cause numerical error, use double precision for tests
  hparams.dtype = tf.float64
  hparams.fused = False  # Fused batch norm does not support tf.float64
  # Reduce the batch size for tests because the OSS version runs
  # in constrained GPU environment with 1-2GB of memory.
  hparams.batch_size = 2
  batch_shape = (hparams.batch_size,) + hparams.input_shape
  self.model = revnet.RevNet(config=hparams)
  self.x = tf.random_normal(shape=batch_shape, dtype=tf.float64)
  self.t = tf.random_uniform(
      shape=[hparams.batch_size],
      minval=0,
      maxval=hparams.n_classes,
      dtype=tf.int64)
  self.config = hparams
def _benchmark_eager_train(self, label, make_iterator, device_and_format,
                           defun=False, execution_mode=None, compiled=False):
  """Times full train iterations over several batch sizes, reporting each."""
  config = config_.get_hparams_imagenet_56()
  config.add_hparam("n_classes", 1000)
  config.add_hparam("dataset", "ImageNet")
  with tfe.execution_mode(execution_mode):
    device, data_format = device_and_format
    for batch_size in self._train_batch_sizes():
      (images, labels) = random_batch(batch_size, config)
      model = revnet.RevNet(config=config)
      optimizer = tf.train.GradientDescentOptimizer(0.1)
      if defun:
        model.call = tfe.defun(model.call)

      warmup_iters = 3
      timed_iters = 10
      with tf.device(device):
        iterator = make_iterator((images, labels))

        # Warm-up iterations are excluded from the timing.
        for _ in range(warmup_iters):
          (images, labels) = iterator.next()
          train_one_iter(model, images, labels, optimizer)
        if execution_mode:
          tfe.async_wait()
        self._force_device_sync()
        gc.collect()

        start = time.time()
        for _ in range(timed_iters):
          (images, labels) = iterator.next()
          train_one_iter(model, images, labels, optimizer)
        if execution_mode:
          tfe.async_wait()
        self._force_device_sync()
        self._report(label, start, timed_iters, device, batch_size,
                     data_format)
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")
  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()

  config = config_.get_hparams_cifar_38()
  model = revnet.RevNet(config=config)

  ds_train = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train",
      data_aug=True,
      batch_size=config.batch_size,
      epochs=config.epochs,
      shuffle=config.shuffle,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)
  ds_validation = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="validation",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)
  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  global_step = tfe.Variable(1, trainable=False)

  def learning_rate():  # TODO(lxuechen): Remove once cl/201089859 is in place
    return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
                                       config.lr_list)

  optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
  if FLAGS.restore:
    latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
    checkpoint.restore(latest_path)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step % config.log_every == 0:
      it_validation = ds_validation.make_one_shot_iterator()
      it_test = ds_test.make_one_shot_iterator()
      acc_validation = evaluate(model, it_validation)
      acc_test = evaluate(model, it_test)
      print("Iter {}, "
            "train loss {}, "
            "validation accuracy {}, "
            "test accuracy {}".format(global_step.numpy(), loss,
                                      acc_validation, acc_test))

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Validation accuracy", acc_validation)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            tf.contrib.summary.scalar("Training loss", loss)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      # Fix: join the directory and prefix properly. The original
      # `FLAGS.train_dir + "ckpt"` drops the path separator, producing a
      # prefix like "traindirckpt" when train_dir lacks a trailing slash;
      # the sibling training scripts use os.path.join as done here.
      checkpoint.save(file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
def model_fn(features, labels, mode, params):
  """Model function required by the `tf.contrib.tpu.TPUEstimator` API.

  Args:
    features: Input images
    labels: Labels of images
    mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT'
    params: A dictionary of extra parameter that might be passed

  Returns:
    An instance of `tf.contrib.tpu.TPUEstimatorSpec`
  """
  revnet_config = params["revnet_config"]
  model = revnet.RevNet(config=revnet_config)

  inputs = features
  if isinstance(inputs, dict):
    inputs = features["image"]

  if revnet_config.data_format == "channels_first":
    assert not FLAGS.transpose_input  # channels_first only for GPU
    inputs = tf.transpose(inputs, [0, 3, 1, 2])

  if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
    inputs = tf.transpose(inputs, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype)
  inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype)

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(
        global_step, revnet_config.lr_decay_steps, revnet_config.lr_list)
    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           revnet_config.momentum)
    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    logits, saved_hidden = model(inputs, training=True)
    grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
    with tf.control_dependencies(model.get_updates_for(inputs)):
      train_op = optimizer.apply_gradients(
          zip(grads, model.trainable_variables), global_step=global_step)

    # Fix: `host_call` must be defined even when the summary host call is
    # skipped; otherwise referencing it below raises a NameError whenever
    # FLAGS.skip_host_call is set.
    host_call = None
    if not FLAGS.skip_host_call:
      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      host_call = (_host_call_fn, [gs_t, loss_t, lr_t])

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, host_call=host_call)

  elif mode == tf.estimator.ModeKeys.EVAL:
    logits, _ = model(inputs, training=False)
    loss = model.compute_loss(labels=labels, logits=logits)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(_metric_fn, [labels, logits]))

  else:  # Predict or export
    logits, _ = model(inputs, training=False)
    predictions = {
        "classes": tf.argmax(logits, axis=1),
        "probabilities": tf.nn.softmax(logits),
    }
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })
def model_fn(features, labels, mode, params):
  """Model function required by the `tf.contrib.tpu.TPUEstimator` API.

  Args:
    features: Input images
    labels: Labels of images
    mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT'
    params: A dictionary of extra parameter that might be passed

  Returns:
    An instance of `tf.contrib.tpu.TPUEstimatorSpec`
  """
  inputs = features
  if isinstance(inputs, dict):
    inputs = features["image"]

  FLAGS = params["FLAGS"]  # pylint:disable=invalid-name,redefined-outer-name
  config = params["config"]
  model = revnet.RevNet(config=config)

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(
        global_step, config.lr_decay_steps, config.lr_list)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=config.momentum)
    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Define gradients
    grads, vars_, logits, loss = model.compute_gradients(
        inputs, labels, training=True)
    train_op = optimizer.apply_gradients(
        zip(grads, vars_), global_step=global_step)
    names = [v.name for v in model.variables]
    tf.logging.warn("{}".format(names))

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)

  if mode == tf.estimator.ModeKeys.EVAL:
    logits, _ = model(inputs, training=False)
    loss = model.compute_loss(labels=labels, logits=logits)

    def metric_fn(labels, logits):
      # Top-1 accuracy computed on the host from the gathered logits.
      predictions = tf.argmax(logits, axis=1)
      accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
      return {
          "accuracy": accuracy,
      }

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits, _ = model(inputs, training=False)
    predictions = {
        "classes": tf.argmax(logits, axis=1),
        "probabilities": tf.nn.softmax(logits),
    }
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  tf.enable_eager_execution()

  config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
  ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(
      data_dir=FLAGS.data_dir, config=config)
  model = revnet.RevNet(config=config)
  global_step = tf.train.get_or_create_global_step()  # Ensure correct summary
  global_step.assign(1)

  learning_rate = tf.train.piecewise_constant(
      global_step, config.lr_decay_steps, config.lr_list)
  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=config.momentum)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.use_defun:
    # Compile the hot paths once so subsequent calls reuse the traced graph.
    model.call = tfe.defun(model.call)
    model.compute_gradients = tfe.defun(model.compute_gradients)
    model.get_moving_stats = tfe.defun(model.get_moving_stats)
    model.restore_moving_stats = tfe.defun(model.restore_moving_stats)
    global apply_gradients  # pylint:disable=global-variable-undefined
    apply_gradients = tfe.defun(apply_gradients)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
  if FLAGS.restore:
    latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
    checkpointer.restore(latest_path)
    print("Restored latest checkpoint at path:\"{}\" "
          "with global_step: {}".format(latest_path, global_step.numpy()))
    sys.stdout.flush()

  for x, y in ds_train:
    train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_test = ds_test.make_one_shot_iterator()
      acc_test, loss_test = evaluate(model, it_test)

      if FLAGS.validate:
        it_train = ds_train_one_shot.make_one_shot_iterator()
        it_validation = ds_validation.make_one_shot_iterator()
        acc_train, loss_train = evaluate(model, it_train)
        acc_validation, loss_validation = evaluate(model, it_validation)
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "validation set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_validation,
                  loss_validation, acc_test, loss_test))
      else:
        print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format(
            global_step.numpy(), acc_test, loss_test))
      sys.stdout.flush()

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            tf.contrib.summary.scalar("Test loss", loss_test)
            if FLAGS.validate:
              tf.contrib.summary.scalar("Training accuracy", acc_train)
              tf.contrib.summary.scalar("Training loss", loss_train)
              tf.contrib.summary.scalar("Validation accuracy", acc_validation)
              tf.contrib.summary.scalar("Validation loss", loss_validation)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" "
            "with global_step: {}".format(saved_path, global_step.numpy()))
      sys.stdout.flush()
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10."""
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")
  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()

  if FLAGS.validate:
    # 40k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)
    # 10k Training set
    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)
  else:
    # 50k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train_all",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)

  # Always compute loss and accuracy on whole training and test set
  ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train_all",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)
  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  model = revnet.RevNet(config=config)
  global_step = tfe.Variable(1, trainable=False)
  learning_rate = tf.train.piecewise_constant(global_step,
                                              config.lr_decay_steps,
                                              config.lr_list)
  optimizer = tf.train.MomentumOptimizer(learning_rate,
                                         momentum=config.momentum)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
  if FLAGS.restore:
    latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
    checkpointer.restore(latest_path)
    print("Restored latest checkpoint at path:\"{}\" "
          "with global_step: {}".format(latest_path, global_step.numpy()))
    sys.stdout.flush()

  warmup(model, config)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_train = ds_train_one_shot.make_one_shot_iterator()
      acc_train, loss_train = evaluate(model, it_train)
      it_test = ds_test.make_one_shot_iterator()
      acc_test, loss_test = evaluate(model, it_test)

      if FLAGS.validate:
        it_validation = ds_validation.make_one_shot_iterator()
        acc_validation, loss_validation = evaluate(model, it_validation)
        # Fix: the validation-loss placeholder was the invalid format spec
        # `{:4.f}` (precision digit missing after '.', raising ValueError at
        # print time) and the adjacent literals concatenated with no
        # separator before "test accuracy". Corrected to `{:.4f}; `.
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "validation set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_validation,
                  loss_validation, acc_test, loss_test))
      else:
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_test,
                  loss_test))
      sys.stdout.flush()

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Training loss", loss)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            if FLAGS.validate:
              tf.contrib.summary.scalar("Validation accuracy", acc_validation)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" "
            "with global_step: {}".format(saved_path, global_step.numpy()))
      sys.stdout.flush()