class ResnetConfig(object):
  """Bundles the datasets, preprocessing function and network used below."""

  def __init__(self):
    self.train_dataset = imagenet.get_split(
        'train', FLAGS.data_dir, labels_dir=FLAGS.labels_dir,
        file_pattern=FLAGS.file_pattern)
    self.eval_dataset = imagenet.get_split(
        'validation', FLAGS.data_dir, labels_dir=FLAGS.labels_dir,
        file_pattern=FLAGS.file_pattern)
    self.image_preprocessing_fn = vgg_preprocessing.preprocess_image
    model = layers_resnet.get_model(FLAGS.model)
    self.network_fn = model(num_classes=self.train_dataset.num_classes)
    # Integer division keeps the epoch <-> step conversions below integral.
    self.batches_per_epoch = self.train_dataset.num_samples // FLAGS.batch_size
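# The setup_learning_rate_schedule() helper called from main() below is not
# shown in this section. What follows is only a hypothetical sketch of what
# such a helper could look like, assuming the conventional step-wise
# ImageNet/ResNet schedule (base rate scaled linearly with batch size, decayed
# 10x at epochs 30/60/80); the boundary epochs and the 0.1/256 base rate are
# assumptions, not values taken from this repository.
_lr_boundaries = None
_lr_values = None


def setup_learning_rate_schedule():
  global _lr_boundaries, _lr_values
  base_lr = 0.1 * FLAGS.batch_size / 256.0  # assumed linear-scaling rule
  decay_epochs = [30, 60, 80]               # assumed decay points
  _lr_boundaries = [int(e * _cfg.batches_per_epoch) for e in decay_epochs]
  _lr_values = [base_lr * (0.1 ** i) for i in range(len(decay_epochs) + 1)]
  # A model_fn could then build the schedule with, e.g.:
  #   tf.train.piecewise_constant(global_step, _lr_boundaries, _lr_values)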
def main(unused_argv):
  global _cfg
  if layers_resnet.get_model(FLAGS.model) is None:
    raise RuntimeError('--model must be one of [' +
                       ', '.join(layers_resnet.get_available_models()) + ']')
  if FLAGS.device not in ['CPU', 'GPU', 'TPU']:
    raise RuntimeError('--device must be one of [CPU, GPU, TPU]')
  if FLAGS.input_layout not in ['NCHW', 'NHWC']:
    raise RuntimeError('--input_layout must be one of [NCHW, NHWC]')
  if FLAGS.winograd_nonfused:
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
  else:
    os.environ.pop('TF_ENABLE_WINOGRAD_NONFUSED', None)

  _cfg = ResnetConfig()
  setup_learning_rate_schedule()

  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=FLAGS.log_device_placement)
  if FLAGS.device == 'GPU':
    session_config.gpu_options.allow_growth = True

  config = tpu_config.RunConfig(
      save_checkpoints_secs=FLAGS.save_checkpoints_secs or None,
      save_summary_steps=FLAGS.save_summary_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards,
          per_host_input_for_training=FLAGS.per_host_input_pipeline),
      session_config=session_config)

  if FLAGS.device == 'GPU' and FLAGS.num_shards > 1:
    run_on_gpu(config)
    return

  resnet_classifier = tpu_estimator.TPUEstimator(
      model_fn=get_model_fn(),
      use_tpu=FLAGS.device == 'TPU',
      config=config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      batch_axis=(get_image_batch_axis(), 0))

  pipeline_input_fn = get_input_pipeline_fn()

  def eval_input(params=None):
    return pipeline_input_fn(params=params, eval_batch_size=FLAGS.batch_size)

  model_conductor.conduct(
      resnet_classifier,
      pipeline_input_fn,
      eval_input,
      get_train_steps(),
      FLAGS.epochs_per_train * _cfg.batches_per_epoch,
      FLAGS.eval_steps,
      train_hooks=get_train_hooks(),
      target_accuracy=FLAGS.target_accuracy)
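# get_train_steps(), used by main() above, is defined elsewhere in this file.
# A minimal sketch, assuming it mirrors the epoch/step logic that appears
# inline in the variant of main() below:
def get_train_steps():
  if FLAGS.train_epochs:
    return FLAGS.train_epochs * _cfg.batches_per_epoch
  return FLAGS.train_steps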
# Alternative variant of main() that drives TPUEstimator directly
# (train, then optionally evaluate) instead of going through model_conductor.
def main(unused_argv):
  global cfg
  if layers_resnet.get_model(FLAGS.model) is None:
    raise RuntimeError('--model must be one of [' +
                       ', '.join(layers_resnet.get_available_models()) + ']')
  if FLAGS.device not in ['CPU', 'GPU', 'TPU']:
    raise RuntimeError('--device must be one of [CPU, GPU, TPU]')
  if FLAGS.input_layout not in ['NCHW', 'NHWC']:
    raise RuntimeError('--input_layout must be one of [NCHW, NHWC]')
  if FLAGS.winograd_nonfused:
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
  else:
    os.environ.pop('TF_ENABLE_WINOGRAD_NONFUSED', None)

  cfg = ResnetConfig()
  hooks = None

  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=FLAGS.log_device_placement)
  if FLAGS.device == 'GPU':
    session_config.gpu_options.allow_growth = True
  if FLAGS.device != 'TPU':
    # Hooks do not work on TPU at the moment.
    tensors_to_log = {'learning_rate': 'learning_rate'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                              every_n_iter=100)
    hooks = [logging_hook]

  config = tpu_config.RunConfig(
      save_checkpoints_secs=FLAGS.save_checkpoints_secs or None,
      save_summary_steps=FLAGS.save_summary_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards),
      session_config=session_config)

  resnet_classifier = tpu_estimator.TPUEstimator(
      model_fn=resnet_model_fn,
      use_tpu=FLAGS.device == 'TPU',
      config=config,
      train_batch_size=FLAGS.batch_size)

  print('Starting to train...')
  if FLAGS.train_epochs:
    train_steps = FLAGS.train_epochs * cfg.batches_per_epoch
  else:
    train_steps = FLAGS.train_steps
  resnet_classifier.train(input_fn=input_fn, max_steps=train_steps,
                          hooks=hooks)

  if FLAGS.eval_steps > 0:
    def eval_input(params=None):
      return input_fn(params=params, eval_batch_size=FLAGS.batch_size)

    print('Starting to evaluate...')
    eval_results = resnet_classifier.evaluate(input_fn=eval_input,
                                              steps=FLAGS.eval_steps)
    print(eval_results)
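# The input_fn referenced above is defined elsewhere in this file. Below is a
# minimal, hypothetical sketch only, assuming the ImageNet splits are stored as
# TFRecord shards under FLAGS.data_dir named '<split>-*', and that a
# parse_record() helper (hypothetical, not part of this file) decodes one
# serialized tf.Example into an image tensor and an integer label.
# TPUEstimator passes the per-shard batch size to input_fn via
# params['batch_size'].
def input_fn(params=None, eval_batch_size=None):
  is_training = eval_batch_size is None
  batch_size = eval_batch_size or (params or {}).get('batch_size',
                                                     FLAGS.batch_size)
  split = 'train' if is_training else 'validation'
  file_pattern = os.path.join(FLAGS.data_dir, '%s-*' % split)  # assumed layout
  dataset = tf.data.Dataset.list_files(file_pattern)
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=8)
  if is_training:
    dataset = dataset.shuffle(1024).repeat()

  def _parse(serialized):
    image, label = parse_record(serialized)  # hypothetical decoder
    image = cfg.image_preprocessing_fn(image, 224, 224,
                                       is_training=is_training)
    return image, label

  dataset = dataset.map(_parse, num_parallel_calls=64)
  # TPUs require a statically known batch size; the last partial batch is
  # simply never reached because the training dataset repeats indefinitely.
  dataset = dataset.batch(batch_size)
  images, labels = dataset.make_one_shot_iterator().get_next()
  return images, labels


# Standard TF 1.x entry point: tf.app.run() parses the flags and calls main().
if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)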