def _run_callbacks_on_batch_end(batch, logs):
  """Runs custom callbacks at the end of every step."""
  mlp_log.mlperf_print(
      'block_stop', None, metadata={
          'first_epoch_num': int(batch),
      })
  if not custom_callbacks:
    return
  for callback in custom_callbacks:
    callback.on_batch_end(batch, logs)
def eval_begin(self):
  """See base class."""
  if self.test_loss:
    self.test_loss.reset_states()
  if self.test_accuracy:
    self.test_accuracy.reset_states()
  # self.test_corrects.reset_states()
  epoch_num = int(self.epoch_helper.current_epoch)
  mlp_log.mlperf_print(
      'eval_start', None, metadata={'epoch_num': epoch_num + 1})
def _run_callbacks_on_batch_begin(batch):
  """Runs custom callbacks at the start of every step."""
  # While BERT pretraining does not have epochs,
  # to make the logging consistent with other MLPerf models,
  # epochs are steps in all the mlp_log entries.
  mlp_log.mlperf_print(
      'block_start', None, metadata={
          'first_epoch_num': int(batch),
          'epoch_count': int(steps_per_loop),
      })
  if not custom_callbacks:
    return
  for callback in custom_callbacks:
    callback.on_batch_begin(batch)
def _run_evaluation(current_training_step, test_iterator):
  """Runs validation steps and aggregates metrics."""
  mlperf_epoch_num = int(current_training_step / steps_between_eval)
  mlp_log.mlperf_print(
      'eval_start', None, metadata={'epoch_num': mlperf_epoch_num})
  for _ in range(eval_steps):
    test_step(test_iterator)
  mlp_log.mlperf_print(
      'eval_stop', None, metadata={'epoch_num': mlperf_epoch_num})

  with eval_summary_writer.as_default():
    masked_lm_accuracy = (
        _float_metric_value(eval_metric_num) /
        _float_metric_value(eval_metric_denom))
    logging.info('Step: [%d] Validation %s = %f', current_training_step,
                 'masked_lm_accuracy', masked_lm_accuracy)
    tf.summary.scalar(
        'masked_lm_accuracy', masked_lm_accuracy, step=current_training_step)
    mlp_log.mlperf_print(
        'eval_accuracy', masked_lm_accuracy,
        metadata={'epoch_num': mlperf_epoch_num})
    eval_summary_writer.flush()
  return masked_lm_accuracy
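# `_run_evaluation` divides two running `tf.keras.metrics.Sum` counters via a
# `_float_metric_value` helper that is referenced but not shown in this
# excerpt. A minimal sketch of such a helper, assuming it only needs to
# materialize a scalar metric's result as a Python float:
def _float_metric_value(metric):
  """Returns a Keras metric's current result as a Python float (sketch)."""
  # Assumption: metrics here are scalar (Mean/Sum), so result() is a scalar
  # tensor that converts directly.
  return float(metric.result().numpy())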
def eval_end(self):
  """See base class."""
  epoch_num = int(self.epoch_helper.current_epoch)
  mlp_log.mlperf_print(
      'eval_stop', None, metadata={'epoch_num': epoch_num + 1})

  eval_accuracy = float(self.test_accuracy.result())
  # eval_accuracy = float(self.test_corrects.result()
  #                      ) / imagenet_preprocessing.NUM_IMAGES['validation']
  # eval_accuracy = float(self.test_accuracy.result()) * \
  #     self.flags_obj.batch_size * self.num_eval_steps / \
  #     imagenet_preprocessing.NUM_IMAGES['validation']
  mlp_log.mlperf_print(
      'eval_accuracy', eval_accuracy, metadata={'epoch_num': epoch_num + 1})

  first_epoch_num = max(epoch_num - self.epochs_between_evals + 1, 0)
  epoch_count = self.epochs_between_evals
  if first_epoch_num == 0:
    epoch_count = self.flags_obj.eval_offset_epochs
    if epoch_count == 0:
      epoch_count = self.flags_obj.epochs_between_evals
  mlp_log.mlperf_print(
      'block_stop', None,
      metadata={
          'first_epoch_num': first_epoch_num + 1,
          'epoch_count': epoch_count
      })

  continue_training = True
  if (eval_accuracy >= self.flags_obj.target_accuracy or
      eval_accuracy <= 0.002):
    continue_training = False
  else:
    mlp_log.mlperf_print(
        'block_start', None,
        metadata={
            'first_epoch_num': epoch_num + 2,
            'epoch_count': self.epochs_between_evals
        })

  results = {}
  if self.test_loss:
    results['test_loss'] = self.test_loss.result()
  if self.test_accuracy:
    results['test_accuracy'] = self.test_accuracy.result()
  results['continue_training'] = continue_training
  return results
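# To make the block bookkeeping above concrete, here is a small worked trace
# of the metadata this method would emit under assumed flag values
# (eval_offset_epochs=2, epochs_between_evals=4); the numbers follow directly
# from the arithmetic above and are illustrative only.
#
# Eval after epoch index 1 (the offset eval):
#   first_epoch_num = max(1 - 4 + 1, 0) = 0 -> logged as 1
#   epoch_count     = eval_offset_epochs   = 2
#   block_stop(first_epoch_num=1, epoch_count=2); next block_start at epoch 3.
# Eval after epoch index 5:
#   first_epoch_num = max(5 - 4 + 1, 0) = 2 -> logged as 3
#   epoch_count     = epochs_between_evals = 4
#   block_stop(first_epoch_num=3, epoch_count=4); next block_start at epoch 7.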
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  mlp_log.mlperf_print('init_start', None)
  common.print_flags(flags_obj)

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
    if not keras_utils.is_v2_0():
      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu,
      tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable
    # get_next_as_optional behavior in DistributedIterator. If true, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=flags_obj.num_classes,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  # drop_remainder = flags_obj.enable_xla

  # The current resnet_model.resnet50 input format is always channels-last.
  # We use the keras_applications MobileNet model, whose input format depends
  # on the Keras backend image data format. The use_keras_image_data_format
  # flag indicates whether the image preprocessor output format should match
  # the Keras backend image data format or just be channels-last.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=flags_obj.drop_train_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=flags_obj.drop_eval_remainder)

  steps_per_epoch, train_epochs = common.get_num_train_iterations(flags_obj)

  mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
  mlp_log.mlperf_print('num_train_examples',
                       imagenet_preprocessing.NUM_IMAGES['train'])
  mlp_log.mlperf_print('num_eval_examples',
                       imagenet_preprocessing.NUM_IMAGES['validation'])

  learning_rate_schedule_fn = None
  with strategy_scope:
    optimizer, learning_rate_schedule_fn = common.get_optimizer(
        flags_obj=flags_obj,
        steps_per_epoch=steps_per_epoch,
        train_steps=train_epochs * steps_per_epoch)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32',
      # which ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    if flags_obj.model == 'resnet50_v1.5':
      resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
      model = resnet_model.resnet50(num_classes=flags_obj.num_classes)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when the minimum TF version
      # supports 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None, classes=flags_obj.num_classes, layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      learning_rate_schedule_fn=learning_rate_schedule_fn,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  num_eval_steps = common.get_num_eval_steps(flags_obj)
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)

  for epoch in range(train_epochs):
    mlp_log.mlperf_print(
        'epoch_start', None,
        metadata={'first_epoch_num': epoch, 'epoch_count': 1})
    mlp_log.mlperf_print('block_start', None)
    history = model.fit(
        train_input_dataset,
        epochs=1,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        verbose=2)
    mlp_log.mlperf_print('block_stop', None)
    eval_output = None
    if not flags_obj.skip_eval:
      mlp_log.mlperf_print('eval_start', None)
      eval_output = model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2)
      mlp_log.mlperf_print('eval_stop', None)
      eval_accuracy = eval_output[1]
      mlp_log.mlperf_print(
          'eval_accuracy', eval_accuracy, metadata={'epoch_num': epoch})
      if eval_accuracy >= flags_obj.target_accuracy:
        break
    mlp_log.mlperf_print('epoch_stop', None)
  mlp_log.mlperf_print('run_stop', None)

  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print('submission_benchmark', 'resnet')
  mlp_log.mlperf_print('submission_division', 'closed')
  mlp_log.mlperf_print('submission_org', 'google')
  mlp_log.mlperf_print(
      'submission_platform',
      'tpu-v3-{}'.format(flags_obj.num_replicas)
      if flags_obj.tpu else 'gpu-v100-{}'.format(flags_obj.num_gpus))
  mlp_log.mlperf_print('submission_status', 'cloud')

  common.print_flags(flags_obj)

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  if tf.config.list_physical_devices('GPU'):
    if flags_obj.tf_gpu_thread_mode:
      datasets_num_private_threads = keras_utils.set_gpu_thread_mode_and_count(
          per_gpu_thread_count=flags_obj.per_gpu_thread_count,
          gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
          num_gpus=flags_obj.num_gpus)
      if not flags_obj.datasets_num_private_threads:
        flags_obj.datasets_num_private_threads = datasets_num_private_threads
    common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu,
      tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)

  mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
  mlp_log.mlperf_print('train_samples',
                       imagenet_preprocessing.NUM_IMAGES['train'])
  mlp_log.mlperf_print('eval_samples',
                       imagenet_preprocessing.NUM_IMAGES['validation'])
  mlp_log.mlperf_print(
      'model_bn_span',
      int(flags_obj.batch_size /
          (flags_obj.num_replicas if flags_obj.tpu else flags_obj.num_gpus)))

  per_epoch_steps, train_epochs = common.get_num_train_iterations(flags_obj)
  eval_steps = common.get_num_eval_steps(flags_obj)
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      train_epochs * per_epoch_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)

  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback)

  eval_interval = (
      flags_obj.epochs_between_evals * per_epoch_steps
      if not flags_obj.skip_eval else None)
  eval_offset = (
      flags_obj.eval_offset_epochs * per_epoch_steps
      if not flags_obj.skip_eval else 0)
  if eval_offset != 0:
    eval_offset -= eval_interval
  checkpoint_interval = (
      per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  device_warmup_steps = (
      flags_obj.device_warmup_steps if flags_obj.enable_device_warmup else 0)
  if flags_obj.enable_device_warmup:
    logging.info('Warmup for %d steps.', device_warmup_steps)

  resnet_controller = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      runnable.warmup,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      train_steps=per_epoch_steps * train_epochs,
      device_warmup_steps=device_warmup_steps,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      eval_steps=eval_steps,
      eval_interval=eval_interval,
      eval_offset=eval_offset)

  if flags_obj.enable_device_warmup:
    resnet_controller.warmup()

  mlp_log.mlperf_print('init_stop', None)

  profile_steps = flags_obj.profile_steps
  if profile_steps:
    profile_steps = [int(i) for i in profile_steps.split(',')]
    if profile_steps[0] < 0:
      runnable.trace_start(-1)

  time_callback.on_train_begin()
  mlp_log.mlperf_print('run_start', None)
  mlp_log.mlperf_print(
      'block_start', None,
      metadata={
          'first_epoch_num': 1,
          'epoch_count': (flags_obj.eval_offset_epochs
                          if flags_obj.eval_offset_epochs != 0
                          else flags_obj.epochs_between_evals)
      })
  resnet_controller.train(evaluate=not flags_obj.skip_eval)
  mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
  time_callback.on_train_end()
  mlp_log.mlperf_print('run_final', None)

  stats = build_stats(runnable, time_callback)
  return stats
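# `build_stats` is referenced above but not included in this excerpt. A hedged
# sketch of what it might collect, reusing the `test_loss`/`test_accuracy`
# metrics the ResnetRunnable exposes in `eval_begin`/`eval_end` above; the
# `timestamp_log` attribute on the TimeHistory callback is a hypothetical
# name, not taken from this excerpt.
def build_stats(runnable, time_callback):
  """Collects final eval/timing stats (illustrative sketch)."""
  stats = {}
  if runnable.test_loss:
    stats['eval_loss'] = float(runnable.test_loss.result())
  if runnable.test_accuracy:
    stats['eval_acc'] = float(runnable.test_accuracy.result())
  # Surface per-step timestamps if the callback recorded them.
  if time_callback and getattr(time_callback, 'timestamp_log', None):
    stats['step_timestamp_log'] = time_callback.timestamp_log
  return stats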
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    steps_between_eval=None,
    steps_before_eval_start=None,
    stop_threshold=None,
    metric_fn=None,
    init_checkpoint=None,
    custom_callbacks=None,
    run_eagerly=False,
    sub_model_export_name=None,
    explicit_allreduce=False,
    device_warmup=False,
    synthetic_train_input_fn=None,
    pre_allreduce_callbacks=None,
    post_allreduce_callbacks=None,
    allreduce_bytes_per_pack=0,
    enable_checkpoint_and_summary=False,
    num_accumulation_steps=1,
    stop_steps=None):
  """Run BERT pretrain model training using the low-level API.

  Arguments:
    _sentinel: Used to prevent positional parameters. Internal, do not use.
    strategy: Distribution strategy on which to run the low-level training
      loop.
    model_fn: Function that returns a tuple (model, sub_model,
      sub_pretrain_model). Caller of this function should add an optimizer to
      the `model` via calling `model.compile()` or manually setting the
      `model.optimizer` attribute. The second element of the returned tuple
      (sub_model) is an optional sub model to be used for the initial
      checkpoint -- if provided.
    loss_fn: Function with signature func(labels, logits) that returns a loss
      tensor.
    model_dir: Model directory used during training for restoring/saving model
      weights.
    train_input_fn: Function that returns a tf.data.Dataset used for training.
    steps_per_epoch: Number of steps to run per epoch. At the end of each
      epoch, a model checkpoint is saved and evaluation is conducted if an
      evaluation dataset is provided.
    steps_per_loop: Number of steps per graph-mode loop. In order to reduce
      communication in eager context, training logs are printed every
      steps_per_loop.
    epochs: Number of epochs to train.
    eval_input_fn: Function that returns the evaluation dataset. If None,
      evaluation is skipped.
    eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
      is not None.
    steps_between_eval: Number of steps between evals.
    steps_before_eval_start: Number of steps to skip before starting eval.
    stop_threshold: Stop threshold for MLPerf once accuracy is achieved.
    metric_fn: A metrics function that returns a Keras Metric object to record
      evaluation results using the evaluation dataset or with the training
      dataset after every epoch.
    init_checkpoint: Optional checkpoint to load into `sub_model` returned by
      `model_fn`.
    custom_callbacks: A list of Keras Callback objects to run during training.
      More specifically, the `on_batch_begin()` and `on_batch_end()` methods
      are invoked during training.
    run_eagerly: Whether to run model training in pure eager execution. This
      should be disabled for TPUStrategy.
    sub_model_export_name: If not None, will export `sub_model` returned by
      `model_fn` into checkpoint files. The name of an intermediate checkpoint
      file is {sub_model_export_name}_step_{step}.ckpt and the last
      checkpoint's name is {sub_model_export_name}.ckpt; if None, `sub_model`
      will not be exported as a checkpoint.
    explicit_allreduce: Whether to explicitly perform gradient allreduce,
      instead of relying on the implicit allreduce in
      optimizer.apply_gradients(). Default is False. For now, if training with
      FP16 mixed precision, explicit allreduce will aggregate gradients in
      FP16 format. For TPU and GPU training using FP32, explicit allreduce
      will aggregate gradients in FP32 format.
    device_warmup: Whether or not to enable device warmup.
      This runs the training and eval loop on synthetic data to pre-compile
      XLA and TF tracing before accessing data.
    synthetic_train_input_fn: Function that returns a synthetic training
      dataset. This is used in device warmup.
    pre_allreduce_callbacks: A list of callback functions that take gradient
      and model variable pairs as input, manipulate them, and return new
      gradient and model variable pairs. The callback functions are invoked in
      list order, before gradients are allreduced. Default is no callbacks.
      Only used when explicit_allreduce=True.
    post_allreduce_callbacks: A list of callback functions that take gradient
      and model variable pairs as input, manipulate them, and return new
      gradient and model variable pairs. The callback functions are invoked in
      list order, right before gradients are applied to variables for updates.
      Default is no callbacks. Only used when explicit_allreduce=True.
    allreduce_bytes_per_pack: A non-negative integer. Breaks collective
      operations into packs of a certain size. If it's zero, all gradients are
      in one pack.
    enable_checkpoint_and_summary: Whether to save checkpoints and summaries.
    stop_steps: The number of steps to run before stopping the training loop.

  Returns:
    Trained model, the final masked LM accuracy, and the last training step.

  Raises:
    ValueError: (1) When the model returned by `model_fn` does not have an
      optimizer attribute or when required parameters are set to None. (2)
      eval args are not specified correctly. (3) metric_fn must be a callable
      if specified. (4) sub_model_export_name is specified, but `sub_model`
      returned by `model_fn` is None.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if [arg for arg in required_arguments if arg is None]:
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
                     'parameters.')
  if steps_between_eval % steps_per_loop != 0:
    raise ValueError(
        'steps_between_eval should be a multiple of steps_per_loop.')
  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        ' steps_per_epoch: %d, we will use steps_per_epoch as'
        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  assert tf.executing_eagerly()

  if run_eagerly:
    if steps_per_loop > 1:
      raise ValueError(
          'steps_per_loop is used for performance optimization. When you want '
          'to run eagerly, you cannot leverage graph mode loop.')
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      raise ValueError(
          'TPUStrategy should not run eagerly as it heavily relies on graph'
          ' optimization for the distributed system.')

  if eval_input_fn and (eval_steps is None):
    raise ValueError(
        '`eval_steps` is required when `eval_input_fn` is not None.')
  if device_warmup and (synthetic_train_input_fn is None):
    raise ValueError('`synthetic_train_input_fn` is required when '
                     'device_warmup is enabled.')
  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  if stop_steps:
    total_training_steps = stop_steps
  else:
    total_training_steps = steps_per_epoch * epochs
  if stop_steps and stop_steps > steps_per_epoch * epochs:
    raise ValueError('`stop_steps` should not be greater than '
                     '`num_train_steps_per_epoch` * `num_epochs`.')

  # To reduce unnecessary send/receive input pipeline operations, we place
  # input pipeline ops in the worker task.
  train_iterator = _get_input_iterator(train_input_fn, strategy)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
    # model and optimizer should be created in scope.
    model, sub_model, sub_pretrain_model = model_fn()
    if not hasattr(model, 'optimizer'):
      raise ValueError('User should set optimizer attribute to model '
                       'inside `model_fn`.')
    if sub_model_export_name and sub_model is None:
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    optimizer = model.optimizer

    train_loss_metric = tf.keras.metrics.Mean('training_loss',
                                              dtype=tf.float32)
    if eval_input_fn:
      eval_metric_num = tf.keras.metrics.Sum('masked_lm_num', dtype=tf.float32)
      eval_metric_denom = tf.keras.metrics.Sum(
          'masked_lm_denom', dtype=tf.float32)

    # If evaluation is required, make a copy of the metric as it will be used
    # by both train and evaluation.
    train_metrics = [
        tf.keras.metrics.Mean('masked_lm_accuracy', dtype=tf.float32)
    ]

    # Create summary writers.
    summary_dir = os.path.join(model_dir, 'summaries')
    if enable_checkpoint_and_summary:
      eval_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'eval'))
    else:
      eval_summary_writer = tf.summary.create_noop_writer()
    if steps_per_loop >= _MIN_SUMMARY_STEPS and enable_checkpoint_and_summary:
      # Only writes summary when the stats are collected sufficiently over
      # enough steps.
      train_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'train'))
    else:
      train_summary_writer = tf.summary.create_noop_writer()

    # Collects training variables.
    training_vars = model.trainable_variables

    @tf.function(experimental_compile=True)
    def _compiled_local_step(inputs, labels, training_vars, accum_vars):
      """Replicated training step."""
      with tf.GradientTape() as tape:
        model_outputs, metric_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
      if isinstance(optimizer,
                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        with tape:
          scaled_loss = optimizer.get_scaled_loss(loss)
        scaled_grads = tape.gradient(scaled_loss, training_vars)
        grads = optimizer.get_unscaled_gradients(scaled_grads)
      else:
        grads = tape.gradient(loss, training_vars)
      (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
      if accum_vars is None:
        return grads, loss, model_outputs, metric_outputs
      else:
        new_accum_vars = []
        for i, grad in enumerate(grads):
          new_accum_vars.append(
              accum_vars[i] +
              tf.math.scalar_mul(1.0 / num_accumulation_steps, grad))
        return new_accum_vars, loss, model_outputs, metric_outputs

    def get_input_slice(input_dict, idx):
      split_input = {}
      for key in input_dict:
        split_input[key] = input_dict[key][idx]
      return split_input

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, labels = inputs
      if explicit_allreduce:
        # TODO(b/155523821): Fix OOM issue so we use experimental_compile with
        # multi-worker mirrored strategy.
        with tf.GradientTape() as tape:
          model_outputs, metric_outputs = model(inputs, training=True)
          loss = loss_fn(labels, model_outputs)
        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     training_vars,
                                                     pre_allreduce_callbacks,
                                                     post_allreduce_callbacks,
                                                     allreduce_bytes_per_pack)
      else:
        if num_accumulation_steps > 1:
          accum_vars = [
              tf.zeros_like(tvar, dtype=tf.float32) for tvar in training_vars
          ]
          for key in inputs:
            inputs[key] = tf.split(inputs[key], num_accumulation_steps)
          split_labels = tf.split(labels, num_accumulation_steps)
          for local_step in range(num_accumulation_steps):
            accum_vars, loss, model_outputs, metric_outputs = (
                _compiled_local_step(
                    get_input_slice(inputs, local_step),
                    split_labels[local_step], training_vars, accum_vars))
          optimizer.apply_gradients(zip(accum_vars, training_vars))
        else:
          grads, loss, model_outputs, metric_outputs = _compiled_local_step(
              inputs, labels, training_vars, None)
          optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(loss)
      for metric in train_metrics:
        metric.update_state(metric_outputs['masked_lm_accuracy'])

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: a tf.int32 integer tensor specifying the number of steps to run
          inside the host training loop.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError('steps should be a Tensor. Python objects may cause '
                         'retracing.')

      for _ in tf.range(steps):
        strategy.run(_replicated_step, args=(next(iterator),))

    def train_single_step(iterator):
      """Performs a distributed training step.

      Args:
        iterator: the distributed iterator of training datasets.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
""" strategy.run(_replicated_step, args=(next(iterator),)) def test_step(iterator): """Calculates evaluation metrics on distributed devices.""" def _test_step_fn(inputs): """Replicated accuracy calculation.""" inputs, labels = inputs model_outputs, metric_outputs = model(inputs, training=False) eval_metric_num.update_state(metric_outputs['masked_lm_num']) eval_metric_denom.update_state(metric_outputs['masked_lm_denom']) strategy.run(_test_step_fn, args=(next(iterator),)) if not run_eagerly: train_single_step = tf.function(train_single_step) test_step = tf.function(test_step) def _run_evaluation(current_training_step, test_iterator): """Runs validation steps and aggregate metrics.""" mlperf_epoch_num = int(current_training_step / steps_between_eval) mlp_log.mlperf_print( 'eval_start', None, metadata={'epoch_num': mlperf_epoch_num}) for _ in range(eval_steps): test_step(test_iterator) mlp_log.mlperf_print( 'eval_stop', None, metadata={'epoch_num': mlperf_epoch_num}) with eval_summary_writer.as_default(): masked_lm_accuracy = ( _float_metric_value(eval_metric_num) / _float_metric_value(eval_metric_denom)) logging.info('Step: [%d] Validation %s = %f', current_training_step, 'masked_lm_accuracy', masked_lm_accuracy) tf.summary.scalar( 'masked_lm_accuracy', masked_lm_accuracy, step=current_training_step) mlp_log.mlperf_print( 'eval_accuracy', masked_lm_accuracy, metadata={'epoch_num': mlperf_epoch_num}) eval_summary_writer.flush() return masked_lm_accuracy def _run_callbacks_on_batch_begin(batch): """Runs custom callbacks at the start of every step.""" # While BERT pretraining does not have epochs, # to make the logging consistent with other mlperf models, # in all the mlp_log, epochs are steps. mlp_log.mlperf_print( 'block_start', None, metadata={ 'first_epoch_num': int(batch), 'epoch_count': int(steps_per_loop), }) if not custom_callbacks: return for callback in custom_callbacks: callback.on_batch_begin(batch) def _run_callbacks_on_batch_end(batch, logs): """Runs custom callbacks at the end of every step.""" mlp_log.mlperf_print( 'block_stop', None, metadata={ 'first_epoch_num': int(batch), }) if not custom_callbacks: return for callback in custom_callbacks: callback.on_batch_end(batch, logs) # Training loop starts here. checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) sub_model_checkpoint = tf.train.Checkpoint( model=sub_model) if sub_model_export_name else None # TODO: commenting this out, as we always load from a initial checkpoint # latest_checkpoint_file = tf.train.latest_checkpoint(model_dir) # if latest_checkpoint_file: # logging.info( # 'Checkpoint file %s found and restoring from ' # 'checkpoint', latest_checkpoint_file) # checkpoint.restore(latest_checkpoint_file) # logging.info('Loading from checkpoint file completed') current_step = optimizer.iterations.numpy() checkpoint_name = 'ctl_step_{step}.ckpt' checkpoint_save_dir = model_dir if enable_checkpoint_and_summary else None if init_checkpoint: logging.info( 'Checkpoint file %s found and restoring from ' 'initial checkpoint for core model.', init_checkpoint) checkpoint = tf.train.Checkpoint(model=sub_pretrain_model) checkpoint.restore(init_checkpoint).assert_existing_objects_matched() logging.info('Loading from checkpoint file completed') if device_warmup: synthetic_train_iterator = _get_input_iterator(synthetic_train_input_fn, strategy) logging.info('Running device warmup for 1 step.') train_steps(synthetic_train_iterator, tf.constant(1, dtype=tf.int32)) # Reset the global step. 
      tf.keras.backend.set_value(optimizer.iterations, 0)
      current_step = optimizer.iterations.numpy()

    masked_lm_accuracy = 0
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    while current_step < total_training_steps:
      # Training loss/metrics are averaged over the steps inside the micro
      # training loop, so we reset their values before each round.
      train_loss_metric.reset_states()
      for metric in train_metrics + model.metrics:
        metric.reset_states()

      _run_callbacks_on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

      train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
      current_step += steps

      # Updates training logging.
      training_status = 'Train Step: %d/%d / loss = %s' % (
          current_step, total_training_steps, train_loss)

      with train_summary_writer.as_default():
        tf.summary.scalar(
            train_loss_metric.name, train_loss, step=current_step)
        for metric in train_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          training_status += ' %s = %f' % (metric.name, metric_value)
          tf.summary.scalar(metric.name, metric_value, step=current_step)
        train_summary_writer.flush()
      logging.info(training_status)

      # Saves model checkpoints and runs validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
        # step of training.
        if current_step < total_training_steps:
          _save_checkpoint(checkpoint, checkpoint_save_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
                sub_model_checkpoint, checkpoint_save_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))

      if eval_input_fn and (current_step % steps_between_eval == 0) and (
          current_step >= steps_before_eval_start):
        logging.info('Running evaluation after step: %s.', current_step)
        masked_lm_accuracy = _run_evaluation(
            current_step, _get_input_iterator(eval_input_fn, strategy))
        if masked_lm_accuracy >= stop_threshold:
          mlp_log.mlperf_print(
              'run_stop', None, metadata={'status': 'success'})
          break

        # Re-initialize evaluation metrics.
        eval_metric_num.reset_states()
        eval_metric_denom.reset_states()

    if masked_lm_accuracy < stop_threshold:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'aborted'})

    _save_checkpoint(checkpoint, checkpoint_save_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
      _save_checkpoint(sub_model_checkpoint, checkpoint_save_dir,
                       '%s.ckpt' % sub_model_export_name)

    if enable_checkpoint_and_summary:
      training_summary = {
          'total_training_steps': total_training_steps,
          'train_loss': _float_metric_value(train_loss_metric),
      }
      if train_metrics:
        # TODO(hongkuny): Cleans up summary reporting in text.
        training_summary['last_train_metrics'] = _float_metric_value(
            train_metrics[0])
        # training_summary['eval_metrics'] = _float_metric_value(
        #     eval_metrics[0])
      write_txt_summary(training_summary, summary_dir)

    return model, masked_lm_accuracy, current_step
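# The loop above relies on a `steps_to_run` helper (not included in this
# excerpt) to decide how many steps the next host loop should execute so that
# loop boundaries line up with epoch boundaries. One plausible implementation,
# offered as a sketch under that assumption:
def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Returns the number of steps for the next host training loop (sketch)."""
  if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be a positive integer.')
  if steps_per_loop == 1:
    return steps_per_loop
  # Run only up to the next epoch boundary so checkpointing/eval can trigger
  # exactly when `current_step % steps_per_epoch == 0`.
  remainder_in_epoch = current_step % steps_per_epoch
  if remainder_in_epoch != 0:
    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
  return steps_per_loop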
def __init__(self,
             batch_size,
             steps_per_epoch,
             train_steps,
             initial_learning_rate=None,
             end_learning_rate=None,
             warmup_epochs=None,
             compute_lr_on_cpu=False,
             name=None):
  """Applies a polynomial decay to the learning rate with warmup."""
  super(PolynomialDecayWithWarmup, self).__init__()

  self.batch_size = batch_size
  self.steps_per_epoch = steps_per_epoch
  self.train_steps = train_steps
  self.name = name
  self.learning_rate_ops_cache = {}
  self.compute_lr_on_cpu = compute_lr_on_cpu

  if batch_size < 16384:
    self.initial_learning_rate = 10.0
    warmup_epochs_ = 5
  elif batch_size < 32768:
    self.initial_learning_rate = 25.0
    warmup_epochs_ = 5
  else:
    self.initial_learning_rate = 31.2
    warmup_epochs_ = 25

  # Override the default poly learning rate and warmup epochs.
  if initial_learning_rate:
    self.initial_learning_rate = initial_learning_rate

  if end_learning_rate:
    self.end_learning_rate = end_learning_rate
  else:
    self.end_learning_rate = 0.0001

  if warmup_epochs is not None:
    warmup_epochs_ = warmup_epochs
  self.warmup_epochs = warmup_epochs_

  opt_name = FLAGS.optimizer.lower()
  mlp_log.mlperf_print('opt_name', opt_name)
  if opt_name == 'lars':
    mlp_log.mlperf_print('{}_epsilon'.format(opt_name), FLAGS.lars_epsilon)
  mlp_log.mlperf_print('{}_opt_weight_decay'.format(opt_name),
                       FLAGS.weight_decay)
  mlp_log.mlperf_print('{}_opt_base_learning_rate'.format(opt_name),
                       self.initial_learning_rate)
  mlp_log.mlperf_print(
      '{}_opt_learning_rate_warmup_epochs'.format(opt_name), warmup_epochs_)
  mlp_log.mlperf_print('{}_opt_end_learning_rate'.format(opt_name),
                       self.end_learning_rate)

  warmup_steps = warmup_epochs_ * steps_per_epoch
  self.warmup_steps = tf.cast(warmup_steps, tf.float32)
  self.decay_steps = train_steps - warmup_steps + 1
  mlp_log.mlperf_print('{}_opt_learning_rate_decay_steps'.format(opt_name),
                       int(self.decay_steps))
  mlp_log.mlperf_print(
      '{}_opt_learning_rate_decay_poly_power'.format(opt_name), 2.0)
  mlp_log.mlperf_print('{}_opt_momentum'.format(opt_name), FLAGS.momentum)

  self.poly_rate_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=self.initial_learning_rate,
      decay_steps=self.decay_steps,
      end_learning_rate=self.end_learning_rate,
      power=2.0)
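# The constructor only precomputes the schedule parameters; the schedule's
# __call__ method is not part of this excerpt. A minimal sketch of how such a
# schedule typically evaluates, assuming linear warmup up to
# initial_learning_rate followed by the wrapped polynomial decay:
def __call__(self, step):
  """Computes the learning rate at `step` (illustrative sketch)."""
  step = tf.cast(step, tf.float32)
  # Linear warmup from 0 to initial_learning_rate over warmup_steps.
  warmup_rate = self.initial_learning_rate * step / self.warmup_steps
  # After warmup, hand off to the Keras PolynomialDecay scheduler created in
  # __init__, offset by the warmup steps.
  poly_rate = self.poly_rate_scheduler(step - self.warmup_steps)
  return tf.where(step <= self.warmup_steps, warmup_rate, poly_rate)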