def main(unused_argv):
  global _cfg
  # Validate the command line flags before building anything.
  if layers_resnet.get_model(FLAGS.model) is None:
    raise RuntimeError('--model must be one of [' +
                       ', '.join(layers_resnet.get_available_models()) + ']')
  if FLAGS.device not in ['CPU', 'GPU', 'TPU']:
    raise RuntimeError('--device must be one of [CPU, GPU, TPU]')
  if FLAGS.input_layout not in ['NCHW', 'NHWC']:
    raise RuntimeError('--input_layout must be one of [NCHW, NHWC]')
  # Toggle the non-fused Winograd convolution kernels via the environment.
  if FLAGS.winograd_nonfused:
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
  else:
    os.environ.pop('TF_ENABLE_WINOGRAD_NONFUSED', None)

  _cfg = ResnetConfig()
  setup_learning_rate_schedule()

  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=FLAGS.log_device_placement)
  if FLAGS.device == 'GPU':
    session_config.gpu_options.allow_growth = True
  config = tpu_config.RunConfig(
      save_checkpoints_secs=FLAGS.save_checkpoints_secs or None,
      save_summary_steps=FLAGS.save_summary_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards,
          per_host_input_for_training=FLAGS.per_host_input_pipeline),
      session_config=session_config)

  # Multi-GPU training does not go through TPUEstimator; hand off to the
  # dedicated GPU path and return early.
  if FLAGS.device == 'GPU' and FLAGS.num_shards > 1:
    run_on_gpu(config)
    return

  resnet_classifier = tpu_estimator.TPUEstimator(
      model_fn=get_model_fn(),
      use_tpu=FLAGS.device == 'TPU',
      config=config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      batch_axis=(get_image_batch_axis(), 0))

  pipeline_input_fn = get_input_pipeline_fn()

  def eval_input(params=None):
    return pipeline_input_fn(params=params, eval_batch_size=FLAGS.batch_size)

  model_conductor.conduct(
      resnet_classifier, pipeline_input_fn, eval_input, get_train_steps(),
      FLAGS.epochs_per_train * _cfg.batches_per_epoch, FLAGS.eval_steps,
      train_hooks=get_train_hooks(), target_accuracy=FLAGS.target_accuracy)
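
# Illustrative only: the full flag set is declared elsewhere in this module, so
# the command below is a sketch rather than a verified invocation. The script
# name is assumed, the placeholder values are not real defaults, and only flags
# actually referenced in main() above are shown.
#
#   python resnet_main.py \
#       --model=<one of layers_resnet.get_available_models()> \
#       --device=TPU --input_layout=NHWC \
#       --batch_size=<global batch size> \
#       --model_dir=<checkpoint/summary directory> \
#       --master=<TPU gRPC address>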
def multigpu_run(config, train_inputfn, eval_inputfn, modelfn, num_gpus,
                 batch_size, shard_axis, weight_decay, momentum, learning_rate,
                 train_steps, eval_steps, steps_per_train, target_accuracy=None,
                 train_hooks=None):
  """Trains and evaluates a model on GPU.

  Args:
    config: The RunConfig object used to configure the Estimator used by the
      GPU model execution.
    train_inputfn: The input function used for training, which returns a tuple
      with the features and the one-hot labels matching the features.
    eval_inputfn: The input function used for evaluation, which returns a tuple
      with the features and the one-hot labels matching the features.
    modelfn: The core model function which builds the model computation graph.
    num_gpus: The number of GPU devices to shard the computation onto.
    batch_size: The global batch size.
    shard_axis: A tuple containing the tensor axes which should be used to
      shard the inputs and the labels, respectively, across GPU devices.
    weight_decay: Weight regularization strength, a float.
    momentum: Momentum for MomentumOptimizer.
    learning_rate: Function object which builds the learning rate graph.
    train_steps: The number of steps to be executed for training.
    eval_steps: The number of steps to be executed for evaluation.
    steps_per_train: How many training steps should be executed before running
      the evaluation steps.
    target_accuracy: If specified, the accuracy target at which to stop the
      training.
    train_hooks: Optional hooks for the training operation.
  """

  def wrapped_modelfn(features, labels, mode):
    # Bind the multi-GPU replication parameters so the Estimator sees a
    # standard (features, labels, mode) model function.
    return multigpu_model_fn(features, labels, mode, modelfn, num_gpus,
                             weight_decay, momentum, learning_rate)

  def train_input_function():
    image_batch, label_batch = train_inputfn()
    return split_batch_input(image_batch, label_batch, batch_size, num_gpus,
                             shard_axis)

  def eval_input_function():
    image_batch, label_batch = eval_inputfn()
    return split_batch_input(image_batch, label_batch, batch_size, num_gpus,
                             shard_axis)

  # Hooks that add extra logging, which is useful to see the loss more often
  # in the console, as well as the examples per second.
  examples_sec_hook = ExamplesPerSecondHook(batch_size, every_n_steps=10)
  hooks = [examples_sec_hook]
  if train_hooks:
    hooks.extend(train_hooks)

  classifier = tf.estimator.Estimator(model_fn=wrapped_modelfn, config=config)
  model_conductor.conduct(
      classifier, train_input_function, eval_input_function, train_steps,
      steps_per_train, eval_steps, train_hooks=hooks,
      target_accuracy=target_accuracy)
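
# Minimal sketch (assumption, not this module's real helper): split_batch_input
# used above is defined elsewhere in the codebase. The function below only
# illustrates how a global batch could be cut into num_gpus per-device shards
# along the axes given by shard_axis, using tf.split; the name
# _example_split_batch_input is hypothetical.
def _example_split_batch_input(image_batch, label_batch, batch_size, num_gpus,
                               shard_axis):
  """Splits a global batch into per-GPU feature and label shards."""
  if batch_size % num_gpus != 0:
    raise ValueError('batch_size must be evenly divisible by num_gpus')
  image_axis, label_axis = shard_axis
  # tf.split returns a list of num_gpus tensors; each entry feeds one GPU
  # tower built by the replicated model function.
  image_shards = tf.split(image_batch, num_gpus, axis=image_axis)
  label_shards = tf.split(label_batch, num_gpus, axis=label_axis)
  return image_shards, label_shards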