def main(argv):
  """Trains and evaluates a RevNet estimator; optionally exports a SavedModel.

  Args:
    argv: Argument list whose first element carries the parsed flags object.
  """
  FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet-specific hyperparameters for the selected dataset.
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  # Estimator runtime configuration. Summary, checkpoint, and step-count
  # logging all follow config.log_every so they stay in lockstep.
  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.train_dir,  # Directory for storing checkpoints
      tf_random_seed=config.seed,
      save_summary_steps=config.log_every,
      save_checkpoints_steps=config.log_every,
      session_config=None,  # Using default
      keep_checkpoint_max=100,
      keep_checkpoint_every_n_hours=10000,  # Using default
      log_step_count_steps=config.log_every,
      train_distribute=None)  # Default: no distribution strategy

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.train_dir,
      config=run_config,
      params={"config": config})

  # Input pipelines: full training split and the held-out test split.
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Run one training pass, then evaluate the resulting model.
  estimator.train(input_fn=train_input_fn)
  estimator.evaluate(input_fn=eval_input_fn)

  # Optionally export a SavedModel that serves raw image tensors.
  if FLAGS.export:
    input_shape = (None,) + config.input_shape
    inputs = tf.placeholder(tf.float32, shape=input_shape)
    serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        {"image": inputs})
    estimator.export_savedmodel(FLAGS.train_dir, serving_input_fn)
def main(argv):
  """Entry point: train, evaluate, and optionally export a RevNet model.

  Args:
    argv: Argument list; element 0 is the parsed flags container.
  """
  FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
  tf.logging.set_verbosity(tf.logging.INFO)

  # Dataset-specific RevNet configuration.
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  # Checkpointing and summary cadence are tied to config.log_every so that
  # summaries, checkpoints, and step-count logs are emitted together.
  estimator_config = tf.estimator.RunConfig(
      model_dir=FLAGS.train_dir,  # Directory for storing checkpoints
      tf_random_seed=config.seed,
      save_summary_steps=config.log_every,
      save_checkpoints_steps=config.log_every,
      session_config=None,  # Using default
      keep_checkpoint_max=100,
      keep_checkpoint_every_n_hours=10000,  # Using default
      log_step_count_steps=config.log_every,
      train_distribute=None)  # Default: no distribution strategy

  revnet_estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.train_dir,
      config=estimator_config,
      params={"config": config})

  # Build the train/eval input pipelines.
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Train once over the schedule, then evaluate on the test split.
  revnet_estimator.train(input_fn=train_input_fn)
  revnet_estimator.evaluate(input_fn=eval_input_fn)

  if FLAGS.export:
    # Export a SavedModel fed by a raw float image placeholder.
    batched_shape = (None,) + config.input_shape
    image_placeholder = tf.placeholder(tf.float32, shape=batched_shape)
    receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        "image": image_placeholder
    })
    revnet_estimator.export_savedmodel(FLAGS.train_dir, receiver_fn)
def main(_):
  """Trains/evaluates RevNet on a TPU via TPUEstimator.

  Behavior depends on FLAGS.mode:
    * "eval": repeatedly evaluates each new checkpoint until training hits
      config.max_train_iter.
    * "train": trains straight through to config.max_train_iter.
    * "train_and_eval": alternates training for FLAGS.steps_per_eval steps
      with a full evaluation pass.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet-specific configuration for the chosen dataset.
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  if FLAGS.use_tpu:
    tf.logging.info("Using TPU.")
    cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    cluster_resolver = None

  # TPU-specific settings. iterations_per_loop is recommended to equal the
  # number of global steps until the next checkpoint.
  tpu_cfg = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=FLAGS.iterations_per_loop,
      num_shards=FLAGS.num_shards)

  run_cfg = tf.contrib.tpu.RunConfig(
      cluster=cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_cfg,
  )

  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=config.tpu_batch_size,
      eval_batch_size=config.tpu_eval_batch_size,
      config=run_cfg,
      params={"config": config})

  # Input pipelines for the full training split and the test split.
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Disabling a range within an else block currently doesn't work
  # due to https://github.com/PyCQA/pylint/issues/872
  # pylint: disable=protected-access
  if FLAGS.mode == "eval":
    # TPUEstimator.evaluate *requires* a steps argument. The number of
    # examples evaluated is eval_steps * batch_size, so the two must be
    # kept consistent (10000 is the test-set size here — TODO confirm).
    eval_steps = 10000 // config.tpu_eval_batch_size

    # Evaluate each checkpoint as it appears.
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info("Starting to evaluate.")
      try:
        t_start = time.time()  # Includes compilation time
        results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt)
        elapsed = int(time.time() - t_start)
        tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
                        (results, elapsed))

        # Stop once the final checkpoint has been evaluated. The step count
        # is parsed from the checkpoint filename ("model.ckpt-<step>").
        step = int(os.path.basename(ckpt).split("-")[1])
        if step >= config.max_train_iter:
          tf.logging.info("Evaluation finished after training step %d" % step)
          break
      except tf.errors.NotFoundError:
        # The coordinator runs on a different job than the TPU worker, so
        # by the time the worker starts evaluating, the checkpoint may
        # already have been garbage-collected. Skip it.
        tf.logging.info(
            "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)

  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    step = estimator_._load_global_step_from_checkpoint_dir(FLAGS.model_dir)
    tf.logging.info("Training for %d steps . Current step %d." %
                    (config.max_train_iter, step))

    t_start = time.time()  # Includes compilation time
    if FLAGS.mode == "train":
      estimator.train(input_fn=train_input_fn,
                      max_steps=config.max_train_iter)
    else:
      eval_steps = 10000 // config.tpu_eval_batch_size
      assert FLAGS.mode == "train_and_eval"
      while step < config.max_train_iter:
        # Train for up to steps_per_eval steps; a checkpoint is written to
        # --model_dir at the end of each training leg.
        next_ckpt_step = min(step + FLAGS.steps_per_eval,
                             config.max_train_iter)
        estimator.train(input_fn=train_input_fn, max_steps=next_ckpt_step)
        step = next_ckpt_step

        # Evaluate the most recent checkpoint. Evaluation runs in batches
        # of --eval_batch_size, so a remainder of images modulo the batch
        # size may be consistently excluded.
        tf.logging.info("Starting to evaluate.")
        results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps)
        tf.logging.info("Eval results: %s" % results)

    elapsed = int(time.time() - t_start)
    tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
                    (config.max_train_iter, elapsed))
def main(argv):
  """TPU entry point for RevNet: train, evaluate, or alternate both.

  Args:
    argv: Argument list; element 0 is the parsed flags container.

  FLAGS.mode selects the behavior:
    * "eval": evaluate every new checkpoint until config.max_train_iter.
    * "train": train straight through to config.max_train_iter.
    * "train_and_eval": interleave FLAGS.steps_per_eval training steps
      with full evaluation passes.
  """
  FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet-specific configuration for the chosen dataset.
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  if FLAGS.use_tpu:
    tf.logging.info("Using TPU.")
    cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    cluster_resolver = None

  # TPU-specific settings; iterations_per_loop is recommended to match the
  # number of global steps until the next checkpoint.
  tpu_cfg = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=FLAGS.iterations_per_loop,
      num_shards=FLAGS.num_shards)

  run_cfg = tf.contrib.tpu.RunConfig(
      cluster=cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_cfg,
  )

  # The model_fn receives both the flags and the RevNet config via params.
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=config.tpu_batch_size,
      eval_batch_size=config.tpu_eval_batch_size,
      config=run_cfg,
      params={
          "FLAGS": FLAGS,
          "config": config,
      })

  # Input pipelines for the full training split and the test split.
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Disabling a range within an else block currently doesn't work
  # due to https://github.com/PyCQA/pylint/issues/872
  # pylint: disable=protected-access
  if FLAGS.mode == "eval":
    # TPUEstimator.evaluate *requires* a steps argument. The number of
    # examples evaluated is eval_steps * batch_size, so the two must be
    # kept consistent (10000 is the test-set size here — TODO confirm).
    eval_steps = 10000 // config.tpu_eval_batch_size

    # Evaluate each checkpoint as it appears.
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info("Starting to evaluate.")
      try:
        t_start = time.time()  # Includes compilation time
        results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt)
        elapsed = int(time.time() - t_start)
        tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
                        (results, elapsed))

        # Stop once the final checkpoint has been evaluated. The step count
        # is parsed from the checkpoint filename ("model.ckpt-<step>").
        step = int(os.path.basename(ckpt).split("-")[1])
        if step >= config.max_train_iter:
          tf.logging.info("Evaluation finished after training step %d" % step)
          break
      except tf.errors.NotFoundError:
        # The coordinator runs on a different job than the TPU worker, so
        # by the time the worker starts evaluating, the checkpoint may
        # already have been garbage-collected. Skip it.
        tf.logging.info(
            "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)

  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    step = estimator_._load_global_step_from_checkpoint_dir(FLAGS.model_dir)
    tf.logging.info("Training for %d steps . Current step %d." %
                    (config.max_train_iter, step))

    t_start = time.time()  # Includes compilation time
    if FLAGS.mode == "train":
      estimator.train(input_fn=train_input_fn,
                      max_steps=config.max_train_iter)
    else:
      eval_steps = 10000 // config.tpu_eval_batch_size
      assert FLAGS.mode == "train_and_eval"
      while step < config.max_train_iter:
        # Train for up to steps_per_eval steps; a checkpoint is written to
        # --model_dir at the end of each training leg.
        next_ckpt_step = min(step + FLAGS.steps_per_eval,
                             config.max_train_iter)
        estimator.train(input_fn=train_input_fn, max_steps=next_ckpt_step)
        step = next_ckpt_step

        # Evaluate the most recent checkpoint. Evaluation runs in batches
        # of --eval_batch_size, so a remainder of images modulo the batch
        # size may be consistently excluded.
        tf.logging.info("Starting to evaluate.")
        results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps)
        tf.logging.info("Eval results: %s" % results)

    elapsed = int(time.time() - t_start)
    tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
                    (config.max_train_iter, elapsed))