def test_training_graph(self):
  """Test model training in graph mode."""
  with tf.Graph().as_default():
    config = config_.get_hparams_cifar_38()
    # Random batch of images and integer labels shaped by the test config.
    inputs = tf.random_normal(
        shape=(self.config.batch_size,) + self.config.input_shape)
    labels = tf.random_uniform(
        shape=(self.config.batch_size,),
        minval=0,
        maxval=self.config.n_classes,
        dtype=tf.int32)
    global_step = tf.Variable(0., trainable=False)
    model = revnet.RevNet(config=config)
    # Build the network once so its batch-norm update ops are created.
    model(inputs)
    updates = model.get_updates_for(inputs)
    inputs_identity = tf.identity(inputs)
    gradients, variables, _ = model.compute_gradients(
        inputs_identity, labels, training=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    # Run the moving-average updates before applying gradients.
    with tf.control_dependencies(updates):
      train_op = optimizer.apply_gradients(
          zip(gradients, variables), global_step=global_step)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(1):
        sess.run(train_op)
def test_training_graph(self):
  """Test model training in graph mode."""
  with tf.Graph().as_default():
    config = config_.get_hparams_cifar_38()
    config.add_hparam("n_classes", 10)
    config.add_hparam("dataset", "cifar-10")
    # Random batch of images and integer labels shaped by the test config.
    images = tf.random_normal(
        shape=(self.config.batch_size,) + self.config.input_shape)
    labels = tf.random_uniform(
        shape=(self.config.batch_size,),
        minval=0,
        maxval=self.config.n_classes,
        dtype=tf.int32)
    global_step = tf.Variable(0., trainable=False)
    model = revnet.RevNet(config=config)
    gradients, variables, _, _ = model.compute_gradients(
        images, labels, training=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.apply_gradients(
        zip(gradients, variables), global_step=global_step)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(1):
        sess.run(train_op)
def test_training_graph(self):
  """Test model training in graph mode."""
  with tf.Graph().as_default():
    config = config_.get_hparams_cifar_38()
    config.add_hparam("n_classes", 10)
    config.add_hparam("dataset", "cifar-10")
    # Random batch of images and integer labels shaped by the test config.
    images = tf.random_normal(
        shape=(self.config.batch_size,) + self.config.input_shape)
    labels = tf.random_uniform(
        shape=(self.config.batch_size,),
        minval=0,
        maxval=self.config.n_classes,
        dtype=tf.int32)
    global_step = tf.Variable(0., trainable=False)
    model = revnet.RevNet(config=config)
    # The forward pass returns the activations needed to reconstruct
    # intermediate states during the backward pass.
    _, hidden_states = model(images)
    grads, _ = model.compute_gradients(saved_hidden=hidden_states,
                                       labels=labels)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.apply_gradients(
        zip(grads, model.trainable_variables), global_step=global_step)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(1):
        sess.run(train_op)
def setUp(self):
  super(RevNetTest, self).setUp()
  config = config_.get_hparams_cifar_38()
  # Activation reconstruction in the reversible blocks accumulates numerical
  # error, so the tests run in double precision.
  config.dtype = tf.float64
  # Fused batch norm has no tf.float64 kernel.
  config.fused = False
  batched_shape = (config.batch_size,) + config.input_shape
  self.model = revnet.RevNet(config=config)
  self.x = tf.random_normal(shape=batched_shape, dtype=tf.float64)
  self.t = tf.random_uniform(
      shape=[config.batch_size],
      minval=0,
      maxval=config.n_classes,
      dtype=tf.int64)
  self.config = config
def get_config():
  """Return configuration."""
  print("Config: {}".format(FLAGS.config))
  sys.stdout.flush()
  hparams_by_name = {
      "revnet-38": config_.get_hparams_cifar_38(),
      "revnet-110": config_.get_hparams_cifar_110(),
      "revnet-164": config_.get_hparams_cifar_164(),
  }
  config = hparams_by_name[FLAGS.config]
  # The base hparams target CIFAR-10; only the class count changes.
  if FLAGS.dataset == "cifar-100":
    config.n_classes = 100
  return config
def get_config():
  """Return configuration."""
  print("Config: {}".format(FLAGS.config))
  sys.stdout.flush()
  # Select the hparams matching the requested architecture.
  available_configs = {
      "revnet-38": config_.get_hparams_cifar_38(),
      "revnet-110": config_.get_hparams_cifar_110(),
      "revnet-164": config_.get_hparams_cifar_164(),
  }
  selected = available_configs[FLAGS.config]
  if FLAGS.dataset == "cifar-100":
    # Base hparams assume CIFAR-10; widen the output layer.
    selected.n_classes = 100
  return selected
def setUp(self):
  super(RevNetTest, self).setUp()
  config = config_.get_hparams_cifar_38()
  # Double precision keeps the reversible-block reconstruction error small
  # enough for the numerical assertions in these tests.
  config.dtype = tf.float64
  # There is no fused batch-norm kernel for tf.float64.
  config.fused = False
  input_batch_shape = (config.batch_size,) + config.input_shape
  self.model = revnet.RevNet(config=config)
  self.x = tf.random_normal(shape=input_batch_shape, dtype=tf.float64)
  self.t = tf.random_uniform(
      shape=[config.batch_size],
      minval=0,
      maxval=config.n_classes,
      dtype=tf.int64)
  self.config = config
def get_config(config_name="revnet-38", dataset="cifar-10"):
  """Return configuration.

  Args:
    config_name: Architecture name; one of "revnet-38", "revnet-110",
      "revnet-164".
    dataset: Dataset name; "cifar-10" or "cifar-100".

  Returns:
    The selected hparams object with `n_classes` and `dataset` added.

  Raises:
    KeyError: If `config_name` is not a known architecture.
    ValueError: If `dataset` is not "cifar-10" or "cifar-100".
  """
  print("Config: {}".format(config_name))
  sys.stdout.flush()
  config = {
      "revnet-38": config_.get_hparams_cifar_38(),
      "revnet-110": config_.get_hparams_cifar_110(),
      "revnet-164": config_.get_hparams_cifar_164(),
  }[config_name]
  # Previously any unrecognized dataset silently fell through to the
  # cifar-100 branch; fail loudly instead.
  if dataset not in ("cifar-10", "cifar-100"):
    raise ValueError("Unknown dataset: {}".format(dataset))
  config.add_hparam("n_classes", 10 if dataset == "cifar-10" else 100)
  config.add_hparam("dataset", dataset)
  return config
def get_config(config_name="revnet-38", dataset="cifar-10"):
  """Return configuration."""
  print("Config: {}".format(config_name))
  sys.stdout.flush()
  all_hparams = {
      "revnet-38": config_.get_hparams_cifar_38(),
      "revnet-110": config_.get_hparams_cifar_110(),
      "revnet-164": config_.get_hparams_cifar_164(),
  }
  config = all_hparams[config_name]
  # Record the dataset choice and class count on the hparams object.
  if dataset == "cifar-10":
    config.add_hparam("n_classes", 10)
    config.add_hparam("dataset", "cifar-10")
  else:
    config.add_hparam("n_classes", 100)
    config.add_hparam("dataset", "cifar-100")
  return config
def setUp(self):
  super(RevNetTest, self).setUp()
  config = config_.get_hparams_cifar_38()
  config.add_hparam("n_classes", 10)
  config.add_hparam("dataset", "cifar-10")
  # Reversible-block reconstruction accumulates numerical error, so the
  # tests run in double precision.
  config.dtype = tf.float64
  # Fused batch norm does not have a tf.float64 kernel.
  config.fused = False
  # Tiny batch: the OSS test environment only has 1-2GB of GPU memory.
  config.batch_size = 2
  batched_shape = (config.batch_size,) + config.input_shape
  self.model = revnet.RevNet(config=config)
  self.x = tf.random_normal(shape=batched_shape, dtype=tf.float64)
  self.t = tf.random_uniform(
      shape=[config.batch_size],
      minval=0,
      maxval=config.n_classes,
      dtype=tf.int64)
  self.config = config
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10.

  Builds the train/validation/test input pipelines from TFRecords under
  FLAGS.data_dir, trains with momentum SGD under a piecewise-constant
  learning-rate schedule, and periodically logs metrics and saves
  checkpoints to FLAGS.train_dir (when set).

  Raises:
    ValueError: If FLAGS.data_dir is unset or does not exist.
  """
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")
  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()
  model = revnet.RevNet(config=config)

  ds_train = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train",
      data_aug=True,
      batch_size=config.batch_size,
      epochs=config.epochs,
      shuffle=config.shuffle,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)
  ds_validation = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="validation",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)
  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  global_step = tfe.Variable(1, trainable=False)

  def learning_rate():  # TODO(lxuechen): Remove once cl/201089859 is in place
    return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
                                       config.lr_list)

  optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpoint.restore(latest_path)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    # Compare plain Python ints: taking `%` of the eager Variable directly
    # would build a throwaway tensor every iteration (the save check below
    # already used .numpy()).
    if global_step.numpy() % config.log_every == 0:
      it_validation = ds_validation.make_one_shot_iterator()
      it_test = ds_test.make_one_shot_iterator()
      acc_validation = evaluate(model, it_validation)
      acc_test = evaluate(model, it_test)
      print("Iter {}, "
            "train loss {}, "
            "validation accuracy {}, "
            "test accuracy {}".format(global_step.numpy(), loss,
                                      acc_validation, acc_test))

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Validation accuracy", acc_validation)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            tf.contrib.summary.scalar("Training loss", loss)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      # os.path.join instead of string concatenation: `train_dir + "ckpt"`
      # silently wrote to a sibling path when train_dir had no trailing
      # separator.
      checkpoint.save(file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10.

  Trains on the 40k split with a 10k validation split when FLAGS.validate
  is set, otherwise on the full 50k training set. Periodically evaluates
  on the whole training and test sets, writes summaries, and checkpoints
  to FLAGS.train_dir (when set).

  Raises:
    ValueError: If FLAGS.data_dir is unset or does not exist.
  """
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")
  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()

  if FLAGS.validate:
    # 40k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)
    # 10k Training set
    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)
  else:
    # 50k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train_all",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)

  # Always compute loss and accuracy on whole training and test set
  ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train_all",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)
  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  model = revnet.RevNet(config=config)
  global_step = tfe.Variable(1, trainable=False)
  learning_rate = tf.train.piecewise_constant(global_step,
                                              config.lr_decay_steps,
                                              config.lr_list)
  optimizer = tf.train.MomentumOptimizer(learning_rate,
                                         momentum=config.momentum)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path:\"{}\" "
            "with global_step: {}".format(latest_path, global_step.numpy()))
      sys.stdout.flush()

  warmup(model, config)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_train = ds_train_one_shot.make_one_shot_iterator()
      acc_train, loss_train = evaluate(model, it_train)
      it_test = ds_test.make_one_shot_iterator()
      acc_test, loss_test = evaluate(model, it_test)

      if FLAGS.validate:
        it_validation = ds_validation.make_one_shot_iterator()
        acc_validation, loss_validation = evaluate(model, it_validation)
        # Fixed: the format spec here was "{:4.f}", which is invalid and
        # raised ValueError the first time this branch ran; also added the
        # missing "; " separator before the test-set metrics.
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "validation set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_validation,
                  loss_validation, acc_test, loss_test))
      else:
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_test,
                  loss_test))
      sys.stdout.flush()

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Training loss", loss)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            if FLAGS.validate:
              tf.contrib.summary.scalar("Validation accuracy", acc_validation)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" "
            "with global_step: {}".format(saved_path, global_step.numpy()))
      sys.stdout.flush()
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10.

  Trains on the 40k split with a 10k validation split when FLAGS.validate
  is set, otherwise on the full 50k training set. Periodically evaluates
  on the whole training and test sets, writes summaries, and checkpoints
  to FLAGS.train_dir (when set).

  Raises:
    ValueError: If FLAGS.data_dir is unset or does not exist.
  """
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")
  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()

  if FLAGS.validate:
    # 40k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)
    # 10k Training set
    ds_validation = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="validation",
        data_aug=False,
        batch_size=config.eval_batch_size,
        epochs=1,
        shuffle=False,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.eval_batch_size)
  else:
    # 50k Training set
    ds_train = cifar_input.get_ds_from_tfrecords(
        data_dir=FLAGS.data_dir,
        split="train_all",
        data_aug=True,
        batch_size=config.batch_size,
        epochs=config.epochs,
        shuffle=config.shuffle,
        data_format=config.data_format,
        dtype=config.dtype,
        prefetch=config.batch_size)

  # Always compute loss and accuracy on whole training and test set
  ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train_all",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)
  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      shuffle=False,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.eval_batch_size)

  model = revnet.RevNet(config=config)
  global_step = tfe.Variable(1, trainable=False)
  learning_rate = tf.train.piecewise_constant(
      global_step, config.lr_decay_steps, config.lr_list)
  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=config.momentum)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path:\"{}\" "
            "with global_step: {}".format(latest_path, global_step.numpy()))
      sys.stdout.flush()

  warmup(model, config)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    if global_step.numpy() % config.log_every == 0:
      it_train = ds_train_one_shot.make_one_shot_iterator()
      acc_train, loss_train = evaluate(model, it_train)
      it_test = ds_test.make_one_shot_iterator()
      acc_test, loss_test = evaluate(model, it_test)

      if FLAGS.validate:
        it_validation = ds_validation.make_one_shot_iterator()
        acc_validation, loss_validation = evaluate(model, it_validation)
        # Fixed: the format spec here was "{:4.f}", which is invalid and
        # raised ValueError the first time this branch ran; also added the
        # missing "; " separator before the test-set metrics.
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "validation set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_validation,
                  loss_validation, acc_test, loss_test))
      else:
        print("Iter {}, "
              "training set accuracy {:.4f}, loss {:.4f}; "
              "test accuracy {:.4f}, loss {:.4f}".format(
                  global_step.numpy(), acc_train, loss_train, acc_test,
                  loss_test))
      sys.stdout.flush()

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Training loss", loss)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            if FLAGS.validate:
              tf.contrib.summary.scalar("Validation accuracy", acc_validation)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" "
            "with global_step: {}".format(saved_path, global_step.numpy()))
      sys.stdout.flush()
def main(_):
  """Eager execution workflow with RevNet trained on CIFAR-10.

  Builds the train/validation/test input pipelines from TFRecords under
  FLAGS.data_dir, trains with momentum SGD under a piecewise-constant
  learning-rate schedule, and periodically logs metrics and saves
  checkpoints to FLAGS.train_dir (when set).

  Raises:
    ValueError: If FLAGS.data_dir is unset or does not exist.
  """
  if FLAGS.data_dir is None:
    raise ValueError("No supplied data directory")
  if not os.path.exists(FLAGS.data_dir):
    raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()
  model = revnet.RevNet(config=config)

  ds_train = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="train",
      data_aug=True,
      batch_size=config.batch_size,
      epochs=config.epochs,
      shuffle=config.shuffle,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)
  ds_validation = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="validation",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)
  ds_test = cifar_input.get_ds_from_tfrecords(
      data_dir=FLAGS.data_dir,
      split="test",
      data_aug=False,
      batch_size=config.eval_batch_size,
      epochs=1,
      data_format=config.data_format,
      dtype=config.dtype,
      prefetch=config.prefetch)

  global_step = tfe.Variable(1, trainable=False)

  def learning_rate():  # TODO(lxuechen): Remove once cl/201089859 is in place
    return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
                                       config.lr_list)

  optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpoint.restore(latest_path)

  for x, y in ds_train:
    loss = train_one_iter(model, x, y, optimizer, global_step=global_step)

    # Compare plain Python ints: taking `%` of the eager Variable directly
    # would build a throwaway tensor every iteration (the save check below
    # already used .numpy()).
    if global_step.numpy() % config.log_every == 0:
      it_validation = ds_validation.make_one_shot_iterator()
      it_test = ds_test.make_one_shot_iterator()
      acc_validation = evaluate(model, it_validation)
      acc_test = evaluate(model, it_test)
      print("Iter {}, "
            "train loss {}, "
            "validation accuracy {}, "
            "test accuracy {}".format(global_step.numpy(), loss,
                                      acc_validation, acc_test))

      if FLAGS.train_dir:
        with summary_writer.as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar("Validation accuracy", acc_validation)
            tf.contrib.summary.scalar("Test accuracy", acc_test)
            tf.contrib.summary.scalar("Training loss", loss)

    if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
      # os.path.join instead of string concatenation: `train_dir + "ckpt"`
      # silently wrote to a sibling path when train_dir had no trailing
      # separator.
      checkpoint.save(file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))