def test_mirrored_strategy(self):
  ds = distribution_utils.get_distribution_strategy(5)
  self.assertFalse(ds.is_single_tower)
  self.assertEqual(ds.num_towers, 5)
  self.assertEqual(len(ds.worker_devices), 5)
  for device in ds.worker_devices:
    self.assertIn('GPU', device)
def __init__(self, strategy_type=None, strategy_config=None):
  _ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                           strategy_config.task_index)
  self._strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=strategy_type,
      num_gpus=strategy_config.num_gpus,
      all_reduce_alg=strategy_config.all_reduce_alg,
      num_packs=strategy_config.num_packs,
      tpu_address=strategy_config.tpu)
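# A minimal usage sketch for the constructor above. `RuntimeConfig` is a
# hypothetical stand-in for the caller's real strategy_config object; only the
# attribute names accessed above are known from the source.
import collections

RuntimeConfig = collections.namedtuple(
    'RuntimeConfig',
    ['worker_hosts', 'task_index', 'num_gpus', 'all_reduce_alg', 'num_packs',
     'tpu'])

single_gpu_config = RuntimeConfig(
    worker_hosts=None,    # no multi-worker cluster
    task_index=-1,
    num_gpus=1,
    all_reduce_alg=None,  # library default
    num_packs=1,
    tpu=None)
# some_runner = SomeRunner(strategy_type='mirrored',
#                          strategy_config=single_gpu_config)  # hypothetical class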
def construct_estimator(flags_obj, params, schedule_manager):
  """Construct an estimator: either an Estimator or a TPUEstimator.

  Args:
    flags_obj: The FLAGS object parsed from the command line.
    params: A dict of run-specific parameters.
    schedule_manager: A schedule.Manager object containing the run schedule.

  Returns:
    An estimator object to be used for training and eval.
  """
  print("============== all_reduce_alg ==============")
  print(flags_obj.all_reduce_alg)
  print("============== all_reduce_alg ==============")

  if not params["use_tpu"]:
    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=flags_obj.model_dir,
        params=params,
        config=tf.estimator.RunConfig(train_distribute=distribution_strategy))

  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu=flags_obj.tpu,
      zone=flags_obj.tpu_zone,
      project=flags_obj.tpu_gcp_project)
  tpu_config = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=schedule_manager.single_iteration_train_steps,
      num_shards=flags_obj.num_tpu_shards)
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=flags_obj.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tpu_config)
  return tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
      train_batch_size=schedule_manager.batch_size,
      eval_batch_size=schedule_manager.batch_size,
      params={
          # TPUEstimator needs to populate batch_size itself due to sharding.
          key: value for key, value in params.items() if key != "batch_size"
      },
      config=run_config)
def test_one_device_strategy_gpu(self):
  ds = distribution_utils.get_distribution_strategy(num_gpus=1)
  self.assertEqual(ds.num_replicas_in_sync, 1)
  self.assertEqual(len(ds.extended.worker_devices), 1)
  self.assertIn('GPU', ds.extended.worker_devices[0])
def test_one_device_strategy_gpu(self):
  ds = distribution_utils.get_distribution_strategy(1)
  self.assertTrue(ds.is_single_tower)
  self.assertEqual(ds.num_towers, 1)
  self.assertEqual(len(ds.worker_devices), 1)
  self.assertIn('GPU', ds.worker_devices[0])
def run_mnist(flags_obj):
  """Run MNIST training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  model_helpers.apply_clean(flags_obj)
  model_function = model_fn

  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      all_reduce_alg=flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'data_format': data_format,
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags_obj.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of
    # times during each training session.
    ds = ds.repeat(flags_obj.epochs_between_evals)
    return ds

  def eval_input_fn():
    return dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size).make_one_shot_iterator().get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  # Train and evaluate model.
  for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model.
  if flags_obj.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn,
                                       strip_default_attrs=True)
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  model_helpers.apply_clean(flags.FLAGS)

  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      warm_start_from=warm_start_settings,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        dtype=flags_core.get_tf_dtype(flags_obj))

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    schedule, n_loops = [0], 1
  else:
    n_loops = math.ceil(flags_obj.train_epochs /
                        flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # over counting

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      classifier.train(
          input_fn=lambda: input_fn_train(num_train_epochs),
          hooks=train_hooks,
          max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')
    eval_results = classifier.evaluate(
        input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    dtype = flags_core.get_tf_dtype(flags_obj)
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size, dtype=dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
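# `distribution_utils.per_device_batch_size` (used above) splits the global
# batch across GPUs. A minimal sketch of that helper, assuming the behavior of
# the TensorFlow models repo version: it requires the global batch size to be
# evenly divisible by the number of GPUs.
def per_device_batch_size_sketch(batch_size, num_gpus):
  if num_gpus <= 1:
    return batch_size
  remainder = batch_size % num_gpus
  if remainder:
    raise ValueError(
        'Batch size must be divisible by the number of GPUs: got batch_size '
        '{} and num_gpus {}.'.format(batch_size, num_gpus))
  return batch_size // num_gpus

# Example: a global batch of 256 on 4 GPUs becomes 64 per device.
assert per_device_batch_size_sketch(256, 4) == 64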
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with all
      the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This
      is used for logging purposes.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  # Initialize our model with all but the dense layer from the pretrained
  # ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      warm_start_from=warm_start_settings,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        dtype=flags_core.get_tf_dtype(flags_obj))

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
    schedule, n_loops = [0], 1
  else:
    # Compute the number of times to loop while training. All but the last
    # pass will train for `epochs_between_evals` epochs, while the last will
    # train for the number needed to reach `train_epochs`. For instance, if
    # train_epochs = 25 and epochs_between_evals = 10, schedule will be set
    # to [10, 10, 5]. That is to say, the loop will:
    #   Train for 10 epochs and then evaluate.
    #   Train for another 10 epochs and then evaluate.
    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
    n_loops = math.ceil(flags_obj.train_epochs /
                        flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # over counting

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      classifier.train(
          input_fn=lambda: input_fn_train(num_train_epochs),
          hooks=train_hooks,
          max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data,
    # which will iterate forever. Passing steps=flags_obj.max_train_steps
    # allows the eval (which is generally unimportant in those circumstances)
    # to terminate. Note that eval will run for max_train_steps each loop,
    # regardless of the global_step count.
    eval_results = classifier.evaluate(
        input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Export a saved model for the given classifier.
    dtype = flags_core.get_tf_dtype(flags_obj)
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size, dtype=dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
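# A standalone sketch of the epoch-schedule computation described in the
# comment above (assumes Python 3 true division): train_epochs=25 with
# epochs_between_evals=10 yields [10, 10, 5], i.e. two full eval cycles plus
# a shorter final one.
import math

def compute_schedule(train_epochs, epochs_between_evals):
  n_loops = math.ceil(train_epochs / epochs_between_evals)
  schedule = [epochs_between_evals for _ in range(int(n_loops))]
  schedule[-1] = train_epochs - sum(schedule[:-1])  # correct the over-count
  return schedule

assert compute_schedule(25, 10) == [10, 10, 5]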
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with all
      the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This
      is used for logging purposes.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train():
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=flags_obj.epochs_between_evals,
        num_gpus=flags_core.get_num_gpus(flags_obj))

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  for cycle_index in range(total_training_cycle):
    tf.logging.info('Starting a training cycle: %d/%d', cycle_index,
                    total_training_cycle)

    classifier.train(
        input_fn=input_fn_train,
        hooks=train_hooks,
        max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data,
    # which will iterate forever. Passing steps=flags_obj.max_train_steps
    # allows the eval (which is generally unimportant in those circumstances)
    # to terminate. Note that eval will run for max_train_steps each loop,
    # regardless of the global_step count.
    eval_results = classifier.evaluate(
        input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Export a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with all
      the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This
      is used for logging purposes.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Creates session config. allow_soft_placement=True is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  # Creates a `RunConfig` that checkpoints every 24 hours, which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_secs=60 * 60 * 24)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    if flags_obj.fine_tune:
      if flags_obj.optimizer.lower() == 'adam':
        if flags_obj.no_dense_init:
          warm_start_settings = tf.estimator.WarmStartSettings(
              flags_obj.pretrained_model_checkpoint_path,
              vars_to_warm_start=[
                  '^(?!.*(resnet_model/dense|beta1_power|beta2_power|Adam|global_step))'
              ])
          # vars_to_warm_start=['^(?!.*(resnet_model/dense|global_step))'])
        else:
          warm_start_settings = tf.estimator.WarmStartSettings(
              flags_obj.pretrained_model_checkpoint_path,
              vars_to_warm_start=[
                  '^(?!.*(resnet_model/dense/kernel/Momentum|resnet_model/dense/bias/Momentum|beta1_power|beta2_power|Adam|global_step))'
              ])
          # vars_to_warm_start=['^(?!.*(resnet_model/dense|global_step))'])
      else:
        if flags_obj.no_dense_init:
          warm_start_settings = tf.estimator.WarmStartSettings(
              flags_obj.pretrained_model_checkpoint_path,
              vars_to_warm_start=[
                  '^(?!.*(resnet_model/dense|Momentum|global_step))'
              ])
        else:
          warm_start_settings = tf.estimator.WarmStartSettings(
              flags_obj.pretrained_model_checkpoint_path,
              vars_to_warm_start=[
                  '^(?!.*(resnet_model/dense/kernel/Momentum|resnet_model/dense/bias/Momentum|global_step))'
              ])
          # vars_to_warm_start=['^(?!.*(resnet_model/dense|global_step))'])
    else:
      if flags_obj.optimizer.lower() == 'adam':
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start=[
                '^(?!.*(endecoder|Momentum|beta1_power|beta2_power|global_step))'
            ])
        # vars_to_warm_start='^(?!.*dense)')
      else:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start=['^(?!.*(endecoder|global_step))'])
        # vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      warm_start_from=warm_start_settings,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune,
          'reconst_loss_scale': flags_obj.reconst_loss_scale,
          'use_ce': flags_obj.use_ce,
          'optimizer': flags_obj.optimizer.lower(),
          'clip_grad': flags_obj.clip_grad,
          'spectral_norm': flags_obj.spectral_norm,
          'ce_scale': flags_obj.ce_scale,
          'sep_grad_nrom': flags_obj.sep_grad_nrom,
          'norm_teach_feature': flags_obj.norm_teach_feature,
          'no_dense_init': flags_obj.no_dense_init,
          'compress_ratio': flags_obj.compress_ratio
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'fine_tune': flags_obj.fine_tune,
      'reconst_loss_scale': flags_obj.reconst_loss_scale,
      'use_ce': flags_obj.use_ce,
      'optimizer': flags_obj.optimizer.lower(),
      'clip_grad': flags_obj.clip_grad,
      'spectral_norm': flags_obj.spectral_norm,
      'ce_scale': flags_obj.ce_scale,
      'sep_grad_nrom': flags_obj.sep_grad_nrom,
      'norm_teach_feature': flags_obj.norm_teach_feature,
      'no_dense_init': flags_obj.no_dense_init,
      'compress_ratio': flags_obj.compress_ratio,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
    schedule, n_loops = [0], 1
  else:
    # Compute the number of times to loop while training. All but the last
    # pass will train for `epochs_between_evals` epochs, while the last will
    # train for the number needed to reach `train_epochs`. For instance, if
    # train_epochs = 25 and epochs_between_evals = 10, schedule will be set
    # to [10, 10, 5]. That is to say, the loop will:
    #   Train for 10 epochs and then evaluate.
    #   Train for another 10 epochs and then evaluate.
    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
    n_loops = math.ceil(flags_obj.train_epochs /
                        flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # over counting

  print('schedule: ', schedule, flags_obj.epochs_between_evals,
        flags_obj.max_train_steps)
  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      classifier.train(
          input_fn=lambda: input_fn_train(num_train_epochs),
          hooks=train_hooks,
          max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data,
    # which will iterate forever. Passing steps=flags_obj.max_train_steps
    # allows the eval (which is generally unimportant in those circumstances)
    # to terminate. Note that eval will run for max_train_steps each loop,
    # regardless of the global_step count.
    eval_results = classifier.evaluate(
        input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)
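# The `vars_to_warm_start` patterns above rely on a regex negative lookahead:
# the pattern matches every variable name EXCEPT those containing one of the
# excluded substrings. A small self-contained demonstration with Python's `re`
# module (the variable names here are illustrative, not from a real
# checkpoint):
import re

pattern = re.compile(r'^(?!.*(resnet_model/dense|Momentum|global_step))')
assert pattern.match('resnet_model/conv2d/kernel')      # warm-started
assert not pattern.match('resnet_model/dense/kernel')   # skipped
assert not pattern.match('global_step')                 # skipped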
def run_keras_model_benchmark(_):
  """Run the benchmark on a Keras model."""
  # Ensure a valid model name was supplied via command line argument.
  if FLAGS.model not in MODELS.keys():
    raise AssertionError("The --model command line argument should "
                         "be a key in the `MODELS` dictionary.")

  # Check if eager execution is enabled.
  if FLAGS.eager:
    tf.logging.info("Eager execution is enabled...")
    tf.enable_eager_execution()

  # Load the model.
  tf.logging.info("Benchmark on {} model...".format(FLAGS.model))
  keras_model = MODELS[FLAGS.model]
  model = keras_model(weights=None)

  # Get dataset.
  dataset_name = "ImageNet"
  if FLAGS.use_synthetic_data:
    tf.logging.info("Using synthetic dataset...")
    dataset_name += "_Synthetic"
    train_dataset = dataset.generate_synthetic_input_dataset(
        FLAGS.model, FLAGS.batch_size)
    val_dataset = dataset.generate_synthetic_input_dataset(
        FLAGS.model, FLAGS.batch_size)
  else:
    raise ValueError("Only synthetic dataset is supported!")

  num_gpus = flags_core.get_num_gpus(FLAGS)

  distribution = None
  # Use distribution strategy.
  if FLAGS.dist_strat:
    distribution = distribution_utils.get_distribution_strategy(
        num_gpus=num_gpus)
  elif num_gpus > 1:
    # Run with multi_gpu_model. If eager execution is enabled, only one GPU
    # is utilized even if multiple GPUs are provided.
    if FLAGS.eager:
      tf.logging.warning(
          "{} GPUs are provided, but only one GPU is utilized as "
          "eager execution is enabled.".format(num_gpus))
    model = tf.keras.utils.multi_gpu_model(model, gpus=num_gpus)

  # Adam and some other optimizers don't work well with distribution
  # strategy (b/113076709), so use GradientDescentOptimizer here.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
  model.compile(
      loss="categorical_crossentropy",
      optimizer=optimizer,
      metrics=["accuracy"],
      distribute=distribution)

  # Create benchmark logger for benchmark logging.
  run_params = {
      "batch_size": FLAGS.batch_size,
      "synthetic_data": FLAGS.use_synthetic_data,
      "train_epochs": FLAGS.train_epochs,
      "num_train_images": FLAGS.num_train_images,
      "num_eval_images": FLAGS.num_eval_images,
  }

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name=FLAGS.model,
      dataset_name=dataset_name,
      run_params=run_params,
      test_id=FLAGS.benchmark_test_id)

  # Create callbacks that log metric values about the training and evaluation.
  callbacks = model_callbacks.get_model_callbacks(
      FLAGS.callbacks,
      batch_size=FLAGS.batch_size,
      metric_logger=benchmark_logger)

  # Train and evaluate the model.
  history = model.fit(
      train_dataset,
      epochs=FLAGS.train_epochs,
      callbacks=callbacks,
      validation_data=val_dataset,
      steps_per_epoch=int(np.ceil(FLAGS.num_train_images / FLAGS.batch_size)),
      validation_steps=int(np.ceil(FLAGS.num_eval_images / FLAGS.batch_size)))

  tf.logging.info("Logging the evaluation results...")
  for epoch in range(FLAGS.train_epochs):
    eval_results = {
        "accuracy": history.history["val_acc"][epoch],
        "loss": history.history["val_loss"][epoch],
        tf.GraphKeys.GLOBAL_STEP:
            (epoch + 1) * np.ceil(FLAGS.num_eval_images / FLAGS.batch_size)
    }
    benchmark_logger.log_evaluation_result(eval_results)

  # Clear the session explicitly to avoid a session delete error.
  tf.keras.backend.clear_session()
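# `dataset.generate_synthetic_input_dataset` is project-specific. A minimal
# sketch of what an ImageNet-shaped synthetic tf.data pipeline typically looks
# like (the 224x224 shape and 1000-class assumption are illustrative):
import tensorflow as tf

def synthetic_input_dataset_sketch(batch_size, height=224, width=224,
                                   num_classes=1000):
  images = tf.zeros([height, width, 3], dtype=tf.float32)
  labels = tf.one_hot(0, num_classes)
  # Repeat a single constant example forever and batch it.
  return tf.data.Dataset.from_tensors(
      (images, labels)).repeat().batch(batch_size)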
def platform_main(train_data, val_data, class_num):
  # Get parameters.
  train_num = len(train_data[0])
  val_num = len(val_data[0])
  batch_size = int(FLAGS.param_batch_size)

  # Set up a RunConfig to only save checkpoints once per training cycle.
  session_config = tf.ConfigProto(allow_soft_placement=True)
  session_config.gpu_options.per_process_gpu_memory_fraction = 0.8

  # Windows does not support NCCL (see
  # https://github.com/tensorflow/tensorflow/issues/21470), so only
  # hierarchical_copy can be used.
  distribution_strategy = distribution_utils.get_distribution_strategy(
      _GPU_NUM, 'hierarchical_copy')
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)
  # run_config = tf.estimator.RunConfig(
  #     session_config=session_config).replace(save_checkpoints_secs=1e9)

  model = tf.estimator.Estimator(
      model_fn=model_spec.model_fn,
      model_dir=FLAGS.Log_dir,
      config=run_config,
      params={
          'model_name': FLAGS.model_name,
          'Log_dir': FLAGS.Log_dir,
          'base_architecture': FLAGS.base_architecture,
          'num_classes': class_num,
          'crop_width': FLAGS.crop_size[0],
          'crop_height': FLAGS.crop_size[1],
          'batch_size': distribution_utils.per_device_batch_size(
              batch_size, _GPU_NUM),
          'tensorboard_images_max_outputs':
              FLAGS.tensorboard_images_max_outputs,
          'weight_decay': FLAGS.weight_decay,
          'learning_rate_policy': FLAGS.learning_rate_policy,
          'num_train': train_num,
          'initial_learning_rate': FLAGS.initial_learning_rate,
          'max_iter': FLAGS.max_iter,
          'end_learning_rate': FLAGS.end_learning_rate,
          'power': FLAGS.decay_power,
          'freeze_batch_norm': FLAGS.freeze_batch_norm,
          'initial_global_step': FLAGS.initial_global_step,
          'class_weights': FLAGS.class_weights
      })

  for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
    tensors_to_log = {
        # 'global_step': 'global_step',
        'learning_rate': 'learning_rate',
        'cross_entropy': 'cross_entropy',
        'train_px_accuracy': 'train_px_accuracy',
        'train_mean_iou': 'train_mean_iou',
    }

    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log, every_n_iter=10)
    train_hooks = [logging_hook]
    eval_hooks = None

    # if FLAGS.debug:
    #   debug_hook = tf_debug.LocalCLIDebugHook()
    #   train_hooks.append(debug_hook)
    #   eval_hooks = [debug_hook]

    tf.logging.info("Start training.")
    model.train(
        input_fn=lambda: input_fn(
            True, train_data,
            distribution_utils.per_device_batch_size(batch_size, _GPU_NUM),
            train_num, FLAGS.epochs_per_eval),
        hooks=train_hooks,
        # steps=1  # For debug
    )

    tf.logging.info("Start evaluation.")
    # Batch size must be 1 for testing because the images' sizes differ.
    eval_results = model.evaluate(
        input_fn=lambda: input_fn(
            False, val_data, batch_size=1, buffer_size=val_num),
        hooks=eval_hooks,
        # steps=1  # For debug
    )
    tf.logging.info(eval_results)
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  resnet_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value (fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = resnet_common.get_synth_input_fn(
        height=cifar10_preprocessing.HEIGHT,
        width=cifar10_preprocessing.WIDTH,
        num_channels=cifar10_preprocessing.NUM_CHANNELS,
        num_classes=cifar10_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar10_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=cifar10_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in the
      # normalization layer, which triggers tf.where and leads to an extra
      # memory copy of input sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar10_preprocessing.parse_record)

  steps_per_epoch = (
      cifar10_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    initial_learning_rate = (
        resnet_common.BASE_LEARNING_RATE * flags_obj.batch_size / 128)
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
        values=[initial_learning_rate] +
        list(p[0] * initial_learning_rate for p in LR_SCHEDULE))

  with strategy_scope:
    optimizer = resnet_common.get_optimizer(lr_schedule)
    model = resnet_cifar_model.resnet56(
        classes=cifar10_preprocessing.NUM_CLASSES)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = resnet_common.get_callbacks(steps_per_epoch)

  if not flags_obj.use_tensor_lr:
    lr_callback = LearningRateBatchScheduler(
        schedule=learning_rate_schedule,
        batch_size=flags_obj.batch_size,
        steps_per_epoch=steps_per_epoch)
    callbacks.append(lr_callback)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      cifar10_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training
    # loop when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(
      train_input_dataset,
      epochs=train_epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=callbacks,
      validation_steps=num_eval_steps,
      validation_data=validation_data,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(
        eval_input_dataset, steps=num_eval_steps, verbose=2)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = resnet_common.build_stats(history, eval_output, callbacks)
  return stats
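# A standalone sketch of the piecewise-constant learning-rate computation
# above. It assumes LR_SCHEDULE is a list of (multiplier, epoch_boundary)
# pairs, as in the TensorFlow models Cifar-10 example; the concrete values and
# steps_per_epoch below are illustrative.
LR_SCHEDULE_EXAMPLE = [(0.1, 91), (0.01, 136), (0.001, 182)]
steps_per_epoch_example = 390        # ~50000 train images / batch size 128
base_lr = 0.1 * 128 / 128            # BASE_LEARNING_RATE scaled by batch size

boundaries = [epoch * steps_per_epoch_example
              for _, epoch in LR_SCHEDULE_EXAMPLE]
values = [base_lr] + [mult * base_lr for mult, _ in LR_SCHEDULE_EXAMPLE]
# tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)
# then returns values[i] for global steps falling in the i-th interval.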
def test_mirrored_strategy(self):
  ds = distribution_utils.get_distribution_strategy(num_gpus=5)
  self.assertEqual(ds.num_replicas_in_sync, 5)
  self.assertEqual(len(ds.extended.worker_devices), 5)
  for device in ds.extended.worker_devices:
    self.assertIn('GPU', device)
def train_and_eval(
    params: base_configs.ExperimentConfig,
    strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
  """Runs the train and eval path using compile/fit."""
  logging.info('Running train and eval.')

  # Note: for TPUs, strategy and scope should be created before the dataset.
  strategy = strategy_override or distribution_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  logging.info('Detected %d devices.',
               strategy.num_replicas_in_sync if strategy else 1)

  label_smoothing = params.model.loss.label_smoothing
  one_hot = label_smoothing and label_smoothing > 0

  builders = _get_dataset_builders(params, strategy, one_hot)
  datasets = [builder.build() if builder else None for builder in builders]

  # Unpack datasets and builders based on train/val/test splits.
  train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
  train_dataset, validation_dataset = datasets

  train_epochs = params.train.epochs
  train_steps = params.train.steps or train_builder.num_steps
  validation_steps = params.evaluation.steps or validation_builder.num_steps

  initialize(params, train_builder)

  logging.info('Global batch size: %d', train_builder.global_batch_size)

  with strategy_scope:
    model_params = params.model.model_params.as_dict()
    model = get_models()[params.model.name](**model_params)
    learning_rate = optimizer_factory.build_learning_rate(
        params=params.model.learning_rate,
        batch_size=train_builder.global_batch_size,
        train_steps=train_steps)
    optimizer = optimizer_factory.build_optimizer(
        optimizer_name=params.model.optimizer.name,
        base_learning_rate=learning_rate,
        params=params.model.optimizer.as_dict())
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=train_builder.dtype == 'float16',
        loss_scale=get_loss_scale(params))

    metrics_map = _get_metrics(one_hot)
    metrics = [metrics_map[metric] for metric in params.train.metrics]

    if one_hot:
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=params.model.loss.label_smoothing)
    else:
      loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(optimizer=optimizer, loss=loss_obj, metrics=metrics)

    initial_epoch = 0
    if params.train.resume_checkpoint:
      initial_epoch = resume_from_checkpoint(
          model=model, model_dir=params.model_dir, train_steps=train_steps)

  serialize_config(params=params, model_dir=params.model_dir)

  # TODO(dankondratyuk): callbacks significantly slow down training.
  callbacks = custom_callbacks.get_callbacks(
      model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
      include_tensorboard=params.train.callbacks.enable_tensorboard,
      time_history=params.train.callbacks.enable_time_history,
      track_lr=params.train.tensorboard.track_lr,
      write_model_weights=params.train.tensorboard.write_model_weights,
      initial_step=initial_epoch * train_steps,
      batch_size=train_builder.global_batch_size,
      log_steps=params.train.time_history.log_steps,
      model_dir=params.model_dir)

  if params.evaluation.skip_eval:
    validation_kwargs = {}
  else:
    validation_kwargs = {
        'validation_data': validation_dataset,
        'validation_steps': validation_steps,
        'validation_freq': params.evaluation.epochs_between_evals,
    }

  model.summary()
  history = model.fit(
      train_dataset,
      epochs=train_epochs,
      steps_per_epoch=train_steps,
      initial_epoch=initial_epoch,
      callbacks=callbacks,
      **validation_kwargs)

  validation_output = None
  if not params.evaluation.skip_eval:
    validation_output = model.evaluate(
        validation_dataset, steps=validation_steps, verbose=2)

  # TODO(dankondratyuk): eval and save final test accuracy.
  stats = resnet_common.build_stats(history, validation_output, callbacks)
  return stats
def run_mnist(flags_obj):
  """Run MNIST training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  model_helpers.apply_clean(flags_obj)
  model_function = model_fn

  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_steps=flags_obj.ckpt_steps,
      keep_checkpoint_max=flags_obj.max_ckpts,
      save_summary_steps=flags_obj.save_summary_steps,
      log_step_count_steps=flags_obj.log_step_count_steps)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'data_format': data_format,
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags_obj.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of
    # times during each training session.
    ds = ds.repeat(flags_obj.epochs_between_evals)
    return ds

  def eval_input_fn():
    return dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size).make_one_shot_iterator().get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn,
      hooks=train_hooks,
      max_steps=flags_obj.max_steps)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=eval_input_fn,
      steps=None,
      start_delay_secs=10,
      throttle_secs=flags_obj.eval_secs)
  tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)

  # Export the model only if export_dir is set and, in a multi-node
  # experiment, only from the master node.
  if os.environ.get('PS_CONFIG') and os.environ.get('TYPE') != 'master':
    tf.logging.debug('No model was exported')
    return

  if flags_obj.export_dir:
    tf.logging.debug(
        'Starting to export model to {}'.format(str(flags_obj.export_dir)))
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(
        flags_obj.export_dir, input_fn, strip_default_attrs=True)
    tf.logging.debug('Model exported')
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with all
      the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This
      is used for logging purposes.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run. Contains the keys `eval_results` and
    `train_hooks`. `eval_results` contains accuracy (top_1) and
    accuracy_top_5. `train_hooks` is a list of the instances of hooks used
    during training.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # Creates session config. allow_soft_placement=True is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Creates a `RunConfig` that checkpoints every 2000 steps rather than on a
  # timer, so checkpoint cadence is determined by training progress.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_secs=None,
      save_checkpoints_steps=2000)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      warm_start_from=warm_start_settings,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj,
                                                  default_for_fp16=128),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune,
          'num_workers': num_workers,
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs, input_context=None):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        input_context=input_context)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                  flags_obj.train_epochs)

  use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
  if use_train_and_evaluate:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=flags_obj.max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    tf.compat.v1.logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evaluate doesn't return anything in the
    # multi-worker case.
    eval_results = {}
  else:
    if train_epochs == 0:
      # If --eval_only is set, perform a single loop with zero train epochs.
      schedule, n_loops = [0], 1
    else:
      # Compute the number of times to loop while training. All but the last
      # pass will train for `epochs_between_evals` epochs, while the last
      # will train for the number needed to reach `train_epochs`. For
      # instance, if train_epochs = 25 and epochs_between_evals = 10,
      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
      #   Train for 10 epochs and then evaluate.
      #   Train for another 10 epochs and then evaluate.
      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
      n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
      schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
      schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting

    for cycle_index, num_train_epochs in enumerate(schedule):
      tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                                int(n_loops))

      if num_train_epochs:
        # Since we are calling classifier.train immediately in each loop, the
        # value of num_train_epochs in the lambda function will not be
        # changed before it is used. So it is safe to ignore the pylint
        # error here.
        # pylint: disable=cell-var-from-loop
        classifier.train(
            input_fn=lambda input_context=None: input_fn_train(
                num_train_epochs, input_context=input_context),
            hooks=train_hooks,
            max_steps=flags_obj.max_train_steps)

      # flags_obj.max_train_steps is generally associated with testing and
      # profiling. As a result it is frequently called with synthetic data,
      # which will iterate forever. Passing steps=flags_obj.max_train_steps
      # allows the eval (which is generally unimportant in those
      # circumstances) to terminate. Note that eval will run for
      # max_train_steps each loop, regardless of the global_step count.
      tf.compat.v1.logging.info('Starting to evaluate.')
      eval_results = classifier.evaluate(
          input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
      benchmark_logger.log_evaluation_result(eval_results)

      if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                           eval_results['accuracy']):
        break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

  stats = {}
  stats['eval_results'] = eval_results
  stats['train_hooks'] = train_hooks

  return stats
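# `distribution_utils.configure_cluster` (used above) derives multi-worker
# configuration from --worker_hosts/--task_index. A minimal sketch of the
# standard TF_CONFIG mechanism it is built around; exact helper internals are
# project-specific:
import json
import os

def configure_cluster_sketch(worker_hosts=None, task_index=-1):
  if not worker_hosts:
    return 1  # single-worker run; no TF_CONFIG needed
  workers = worker_hosts.split(',')
  # Each worker process exports the same cluster spec plus its own task index.
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {'worker': workers},
      'task': {'type': 'worker', 'index': task_index},
  })
  return len(workers)

# Example: two workers; this process is index 0.
assert configure_cluster_sketch('host1:2222,host2:2222', 0) == 2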