def _run_bert_classifier(self, callbacks=None, use_ds=True):
  """Launches the BERT classification run described by FLAGS.

  Training/eval sizes are read from the JSON file at
  `FLAGS.input_meta_data_path`. `self.num_epochs` and
  `self.num_steps_per_epoch`, when truthy, take precedence over the values
  derived from flags and metadata.

  Args:
    callbacks: Optional custom callbacks forwarded to
      `run_classifier.run_bert_classifier` as `custom_callbacks`.
    use_ds: When no TPU is configured, selects the 'mirrored' distribution
      strategy if True and 'off' otherwise.
  """
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
    raw_meta = meta_file.read()
  input_meta_data = json.loads(raw_meta.decode('utf-8'))

  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  # Benchmark-level overrides win over the flag / metadata derived values.
  epochs = self.num_epochs or FLAGS.num_train_epochs
  steps_per_epoch = self.num_steps_per_epoch
  if not steps_per_epoch:
    steps_per_epoch = int(
        input_meta_data['train_data_size'] / FLAGS.train_batch_size)

  # Warm up over the first 10% of all training steps.
  total_train_steps = epochs * steps_per_epoch
  warmup_steps = int(total_train_steps * 0.1)

  eval_batches = input_meta_data['eval_data_size'] / FLAGS.eval_batch_size
  eval_steps = int(math.ceil(eval_batches))

  if self.tpu:
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='tpu', tpu_address=self.tpu)
  else:
    ds_name = 'mirrored' if use_ds else 'off'
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=ds_name, num_gpus=self.num_gpus)

  steps_per_loop = 50
  max_seq_length = input_meta_data['max_seq_length']

  train_input_fn = run_classifier.get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)
  eval_input_fn = run_classifier.get_dataset_fn(
      FLAGS.eval_data_path,
      max_seq_length,
      FLAGS.eval_batch_size,
      is_training=False)

  run_classifier.run_bert_classifier(
      strategy,
      bert_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=callbacks)
def __init__(self, strategy_type=None, strategy_config=None):
  """Configures the cluster and creates the distribution strategy.

  Args:
    strategy_type: Strategy name forwarded as `distribution_strategy` to
      `distribution_utils.get_distribution_strategy`.
    strategy_config: Object exposing `worker_hosts`, `task_index`,
      `num_gpus`, `all_reduce_alg`, `num_packs` and `tpu`. Despite the `None`
      default it is dereferenced unconditionally, so a real config object
      must be supplied.
  """
  config = strategy_config

  # The cluster spec has to be in place before the strategy is built.
  distribution_utils.configure_cluster(config.worker_hosts, config.task_index)

  strategy_kwargs = dict(
      distribution_strategy=strategy_type,
      num_gpus=config.num_gpus,
      all_reduce_alg=config.all_reduce_alg,
      num_packs=config.num_packs,
      tpu_address=config.tpu)
  self._strategy = distribution_utils.get_distribution_strategy(
      **strategy_kwargs)
def _get_distribution_strategy(self, ds_type='mirrored'):
  """Builds the distribution strategy for this benchmark.

  A configured `self.tpu` always selects the TPU strategy, regardless of
  `ds_type`.

  Args:
    ds_type: String, the distribution strategy type to be used. Can be
      'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.

  Returns:
    A `tf.distribute.DistributionStrategy` object.
  """
  wants_tpu = self.tpu or ds_type == 'tpu'
  if wants_tpu:
    return distribution_utils.get_distribution_strategy(
        distribution_strategy='tpu', tpu_address=self.tpu)

  if ds_type == 'multi_worker_mirrored':
    # The multi-worker strategy needs the cluster spec configured first.
    distribution_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)

  return distribution_utils.get_distribution_strategy(
      distribution_strategy=ds_type,
      num_gpus=self.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Training is driven one epoch per `model.fit` call so that MLPerf
  epoch/block/eval markers can be logged around every epoch. When evaluation
  is enabled and accuracy metrics are reported, training stops as soon as the
  evaluated accuracy reaches `flags_obj.target_accuracy`.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `--dtype=fp16` is requested on TensorFlow 1, or if
      `flags_obj.model` is neither 'resnet50_v1.5' nor 'mobilenet'.
    NotImplementedError: If some features are not currently supported
      (pruning with a non-float32 dtype, or an unknown pruning method).

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # pylint: disable=g-import-not-at-top

  mlp_log.mlperf_print('init_start', None)
  common.print_flags(flags_obj)
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  mixed_precision = tf.compat.v2.keras.mixed_precision.experimental
  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    # Reject the unsupported configuration before mutating the global
    # mixed-precision policy.
    if not keras_utils.is_v2_0():
      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    policy = mixed_precision.Policy('mixed_float16', loss_scale=loss_scale)
    mixed_precision.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = mixed_precision.Policy('mixed_bfloat16')
    mixed_precision.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                       flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu,
      tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=flags_obj.num_classes,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use the keras_application mobilenet model, whose input format depends
  # on the keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should be the same as the keras backend image
  # data format or just channel-last format.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=flags_obj.drop_train_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=flags_obj.drop_eval_remainder)

  steps_per_epoch, train_epochs = common.get_num_train_iterations(flags_obj)

  mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
  mlp_log.mlperf_print('num_train_examples',
                       imagenet_preprocessing.NUM_IMAGES['train'])
  mlp_log.mlperf_print('num_eval_examples',
                       imagenet_preprocessing.NUM_IMAGES['validation'])

  learning_rate_schedule_fn = None

  with strategy_scope:
    optimizer, learning_rate_schedule_fn = common.get_optimizer(
        flags_obj=flags_obj,
        steps_per_epoch=steps_per_epoch,
        train_steps=train_epochs * steps_per_epoch)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    if flags_obj.model == 'resnet50_v1.5':
      resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
      model = resnet_model.resnet50(num_classes=flags_obj.num_classes)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=flags_obj.num_classes,
          layers=tf.keras.layers)
    else:
      # Fail with a clear message instead of an unbound `model` below.
      raise ValueError('Unsupported model: {}'.format(flags_obj.model))

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    compile_kwargs = {}
    if flags_obj.force_v2_in_keras_compile is not None:
      compile_kwargs['experimental_run_tf_function'] = (
          flags_obj.force_v2_in_keras_compile)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly,
        **compile_kwargs)

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      learning_rate_schedule_fn=learning_rate_schedule_fn,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  num_eval_steps = common.get_num_eval_steps(flags_obj)

  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None

  # The optional device scope is managed by an ExitStack so that it is always
  # exited (also when training raises) and `__exit__` receives the standard
  # context-manager arguments.
  with contextlib.ExitStack() as device_stack:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training
      # loop when not using distribution strategy.
      device_stack.enter_context(tf.device('/device:GPU:0'))

    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    for epoch in range(train_epochs):
      mlp_log.mlperf_print(
          'epoch_start',
          None,
          metadata={
              'first_epoch_num': epoch,
              'epoch_count': 1
          })
      mlp_log.mlperf_print('block_start', None)
      history = model.fit(
          train_input_dataset,
          epochs=1,
          steps_per_epoch=steps_per_epoch,
          callbacks=callbacks,
          verbose=2)
      mlp_log.mlperf_print('block_stop', None)

      eval_output = None
      if not flags_obj.skip_eval:
        mlp_log.mlperf_print('eval_start', None)
        eval_output = model.evaluate(
            eval_input_dataset, steps=num_eval_steps, verbose=2)
        mlp_log.mlperf_print('eval_stop', None)

        # Without compiled metrics `evaluate` returns only the scalar loss,
        # so there is no accuracy entry to read or to stop early on.
        if flags_obj.report_accuracy_metrics:
          eval_accuracy = eval_output[1]
          mlp_log.mlperf_print(
              'eval_accuracy', eval_accuracy, metadata={'epoch_num': epoch})
          if eval_accuracy >= flags_obj.target_accuracy:
            break
      mlp_log.mlperf_print('epoch_stop', None)
    mlp_log.mlperf_print('run_stop', None)

    if flags_obj.pruning_method:
      model = tfmot.sparsity.keras.strip_pruning(model)
    if flags_obj.enable_checkpoint_and_export:
      if dtype == tf.bfloat16:
        logging.warning('Keras model.save does not support bfloat16 dtype.')
      else:
        # Keras model.save assumes a float32 input signature.
        export_path = os.path.join(flags_obj.model_dir, 'saved_model')
        model.save(export_path, include_optimizer=False)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Emits MLPerf log markers around initialization and the training run, builds
  a `ResnetRunnable` under the distribution strategy scope and drives it with
  a `controller.Controller`.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: Documented by the original author for unsupported fp16
      configurations. It is not raised directly in this function; presumably
      it comes from a helper -- TODO confirm.

  Returns:
    Dictionary of training and eval stats.
  """
  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print('submission_benchmark', 'resnet')
  mlp_log.mlperf_print('submission_division', 'closed')
  mlp_log.mlperf_print('submission_org', 'google')
  mlp_log.mlperf_print(
      'submission_platform',
      'tpu-v3-{}'.format(flags_obj.num_replicas)
      if flags_obj.tpu else 'gpu-v100-{}'.format(flags_obj.num_gpus))
  mlp_log.mlperf_print('submission_status', 'cloud')

  common.print_flags(flags_obj)
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  if tf.config.list_physical_devices('GPU'):
    if flags_obj.tf_gpu_thread_mode:
      datasets_num_private_threads = keras_utils.set_gpu_thread_mode_and_count(
          per_gpu_thread_count=flags_obj.per_gpu_thread_count,
          gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
          num_gpus=flags_obj.num_gpus)
      # Only adopt the computed thread count when the flag was left unset.
      if not flags_obj.datasets_num_private_threads:
        flags_obj.datasets_num_private_threads = datasets_num_private_threads
    common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu,
      tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)

  mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
  mlp_log.mlperf_print('train_samples',
                       imagenet_preprocessing.NUM_IMAGES['train'])
  mlp_log.mlperf_print('eval_samples',
                       imagenet_preprocessing.NUM_IMAGES['validation'])
  # A run without accelerators (e.g. num_gpus == 0) is treated as a single
  # device so that this purely informational log line cannot abort the run
  # with a ZeroDivisionError.
  num_devices = (
      flags_obj.num_replicas if flags_obj.tpu else flags_obj.num_gpus) or 1
  mlp_log.mlperf_print('model_bn_span',
                       int(flags_obj.batch_size / num_devices))

  per_epoch_steps, train_epochs = common.get_num_train_iterations(flags_obj)
  eval_steps = common.get_num_eval_steps(flags_obj)
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      train_epochs * per_epoch_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback)

  eval_interval = (
      flags_obj.epochs_between_evals *
      per_epoch_steps if not flags_obj.skip_eval else None)
  eval_offset = (
      flags_obj.eval_offset_epochs *
      per_epoch_steps if not flags_obj.skip_eval else 0)
  # A non-zero offset implies evaluation is enabled, so eval_interval is an
  # integer here.
  if eval_offset != 0:
    eval_offset -= eval_interval
  checkpoint_interval = (
      per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  device_warmup_steps = (
      flags_obj.device_warmup_steps if flags_obj.enable_device_warmup else 0)
  if flags_obj.enable_device_warmup:
    logging.info('Warmup for %d steps.', device_warmup_steps)

  resnet_controller = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      runnable.warmup,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      train_steps=per_epoch_steps * train_epochs,
      device_warmup_steps=device_warmup_steps,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      eval_steps=eval_steps,
      eval_interval=eval_interval,
      eval_offset=eval_offset)

  if flags_obj.enable_device_warmup:
    resnet_controller.warmup()

  mlp_log.mlperf_print('init_stop', None)

  profile_steps = flags_obj.profile_steps
  if profile_steps:
    profile_steps = [int(i) for i in profile_steps.split(',')]
    # A negative first entry starts tracing before training begins.
    if profile_steps[0] < 0:
      runnable.trace_start(-1)

  time_callback.on_train_begin()
  mlp_log.mlperf_print('run_start', None)
  mlp_log.mlperf_print(
      'block_start',
      None,
      metadata={
          'first_epoch_num':
              1,
          'epoch_count':
              (flags_obj.eval_offset_epochs if flags_obj.eval_offset_epochs != 0
               else flags_obj.epochs_between_evals)
      })
  resnet_controller.train(evaluate=not flags_obj.skip_eval)
  mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
  time_callback.on_train_end()
  mlp_log.mlperf_print('run_final', None)

  stats = build_stats(runnable, time_callback)
  return stats
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # pylint: disable=g-import-not-at-top

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # Compare against the TensorFlow dtype: the previous comparison with the
  # string 'fp16' could never match, so the documented error was never raised.
  # Assumes get_tf_dtype returns a tf.DType -- TODO confirm.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in
      # normalization layer, which triggers tf.where and leads to extra memory
      # copy of input sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record)

  steps_per_epoch = (
      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    initial_learning_rate = (
        common.BASE_LEARNING_RATE * flags_obj.batch_size / 128)
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
        values=[initial_learning_rate] +
        [p[0] * initial_learning_rate for p in LR_SCHEDULE])

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = resnet_cifar_model.resnet56(
        classes=cifar_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    compile_kwargs = {}
    if flags_obj.force_v2_in_keras_compile is not None:
      compile_kwargs['experimental_run_tf_function'] = (
          flags_obj.force_v2_in_keras_compile)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly,
        **compile_kwargs)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(steps_per_epoch)

  if not flags_obj.use_tensor_lr:
    # Without a tensor schedule the learning rate is driven per batch by a
    # callback instead.
    lr_callback = LearningRateBatchScheduler(
        schedule=learning_rate_schedule,
        batch_size=flags_obj.batch_size,
        steps_per_epoch=steps_per_epoch)
    callbacks.append(lr_callback)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      cifar_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # The optional device scope is managed by an ExitStack so that it is always
  # exited (also when training raises) and `__exit__` receives the standard
  # context-manager arguments.
  with contextlib.ExitStack() as device_stack:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training
      # loop when not using distribution strategy.
      device_stack.enter_context(tf.device('/device:GPU:0'))

    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats