def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable
    # get_next_as_optional behavior in DistributedIterator. If true, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # The resnet_model.resnet50 input format is always channels-last, whereas
  # the keras_applications MobileNet input format depends on the Keras backend
  # image data format. The use_keras_image_data_format flag indicates whether
  # the image preprocessor output format should match the Keras backend image
  # data format or stay channels-last.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32',
      # which ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when the minimum TF version
      # supports 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
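# A minimal launcher sketch for the compile/fit path above. The flag
# registration helper `common.define_keras_flags()` is an assumption here
# (adjust to however this repo defines its flags); `absl.app` / `absl.flags`
# are the standard absl entry points used by these scripts.
def main(_):
  stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  common.define_keras_flags()  # assumed helper that registers the flags read by run()
  app.run(main)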
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  tf.get_logger().propagate = False
  output_dir = None
  if "LOG_DIR" in os.environ:
    output_dir = os.environ["LOG_DIR"]
  mlperf_mlloger, mlperf_mllog = get_mllog_mlloger(output_dir)
  mlperf_mlloger.event(key=mlperf_mllog.constants.CACHE_CLEAR, value=True)
  mlperf_mlloger.start(key=mlperf_mllog.constants.INIT_START, value=None)
  mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_BENCHMARK,
                       value=mlperf_mllog.constants.RESNET)
  mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_ORG,
                       value='Habana')
  mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_DIVISION,
                       value='closed')
  mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_PLATFORM,
                       value='gaudi-{}'.format(flags_obj.num_gpus))
  mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_STATUS,
                       value='onprem')

  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  if horovod_enabled():
    model_dir = os.path.join(flags_obj.model_dir, "worker_" + str(hvd.rank()))
  else:
    model_dir = flags_obj.model_dir

  global_batch_size = get_global_batch_size(flags_obj.batch_size)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  mlperf_mlloger.event(key=mlperf_mllog.constants.GLOBAL_BATCH_SIZE,
                       value=global_batch_size)
  mlperf_mlloger.event(key=mlperf_mllog.constants.TRAIN_SAMPLES,
                       value=imagenet_preprocessing.NUM_IMAGES['train'])
  mlperf_mlloger.event(key=mlperf_mllog.constants.EVAL_SAMPLES,
                       value=imagenet_preprocessing.NUM_IMAGES['validation'])
  group_batch_norm = 1
  mlperf_mlloger.event(key=mlperf_mllog.constants.MODEL_BN_SPAN,
                       value=flags_obj.batch_size * group_batch_norm)

  train_writer, eval_writer = None, None
  if flags_obj.enable_tensorboard:
    train_writer = tf.summary.create_file_writer(model_dir)
    eval_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'eval'))
    hparams = flags_obj.flag_values_dict()
    write_hparams_v2(train_writer, hparams)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
  train_steps = train_epochs * per_epoch_steps

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      train_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      global_batch_size,
      flags_obj.log_steps,
      summary_writer=train_writer,
      batch_size_per_node=flags_obj.batch_size)
  profiler_callback = None
  if flags_obj.profile_steps is not None:
    profiler_callback = keras_utils.get_profiler_callback(
        model_dir, flags_obj.profile_steps, flags_obj.enable_tensorboard,
        per_epoch_steps)

  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              train_steps, per_epoch_steps,
                                              profiler_callback,
                                              mlperf_mlloger, mlperf_mllog)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
  eval_offset = flags_obj.eval_offset_epochs * per_epoch_steps
  if eval_offset != 0:
    eval_offset -= eval_interval
  checkpoint_interval = (
      per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  device_warmup_steps = (
      flags_obj.device_warmup_steps if flags_obj.enable_device_warmup else 0)
  if flags_obj.enable_device_warmup:
    logging.info('Warmup for %d steps.', device_warmup_steps)

  train_steps = per_epoch_steps * train_epochs

  resnet_controller = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      runnable.warmup,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      train_steps=train_steps,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      eval_steps=eval_steps,
      eval_interval=eval_interval,
      eval_offset=eval_offset,
      device_warmup_steps=device_warmup_steps,
      train_summary_writer=train_writer,
      eval_summary_writer=eval_writer)

  if flags_obj.enable_device_warmup:
    resnet_controller.warmup()

  mlperf_mlloger.end(key=mlperf_mllog.constants.INIT_STOP)

  hvd.broadcast(0, 0)
  time_callback.on_train_begin()
  mlperf_mlloger.start(key=mlperf_mllog.constants.RUN_START)
  mlperf_mlloger.start(
      key=mlperf_mllog.constants.BLOCK_START,
      value=None,
      metadata={
          'first_epoch_num': 1,
          'epoch_count': (flags_obj.eval_offset_epochs
                          if flags_obj.eval_offset_epochs > 0
                          else flags_obj.epochs_between_evals)
      })
  resnet_controller.train(evaluate=not flags_obj.skip_eval)

  if not flags_obj.skip_eval:
    eval_accuracy = resnet_controller.last_eval_output['test_accuracy']
    if eval_accuracy >= flags_obj.target_accuracy:
      mlperf_mlloger.end(key=mlperf_mllog.constants.RUN_STOP,
                         value=None,
                         metadata={'status': 'success'})
    else:
      mlperf_mlloger.end(key=mlperf_mllog.constants.RUN_STOP,
                         value=None,
                         metadata={'status': 'fail'})

  time_callback.on_train_end()
  stats = build_stats(runnable, time_callback)
  return stats
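# Sketch of how the eval_offset / eval_interval values computed above space
# out evaluations. The controller's trigger condition is an assumption here
# (evaluate whenever (step - eval_offset) % eval_interval == 0), and the
# concrete numbers (10 steps/epoch, eval every 4 epochs, offset 3 epochs)
# are illustrative only.
def _eval_schedule_demo(per_epoch_steps=10,
                        epochs_between_evals=4,
                        eval_offset_epochs=3,
                        total_epochs=12):
  eval_interval = epochs_between_evals * per_epoch_steps
  eval_offset = eval_offset_epochs * per_epoch_steps
  if eval_offset != 0:
    eval_offset -= eval_interval  # mirrors the adjustment in run() above
  eval_epochs = [
      step // per_epoch_steps
      for step in range(1, total_epochs * per_epoch_steps + 1)
      if (step - eval_offset) % eval_interval == 0
  ]
  return eval_epochs  # -> [3, 7, 11]: first eval at the offset, then every 4 epochs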
def train_and_eval(
    params: base_configs.ExperimentConfig,
    strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
  """Runs the train and eval path using compile/fit."""
  logging.info('Running train and eval.')

  # Note: for TPUs, strategy and scope should be created before the dataset.
  strategy = strategy_override or distribution_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  logging.info('Detected %d devices.',
               strategy.num_replicas_in_sync if strategy else 1)

  label_smoothing = params.model.loss.label_smoothing
  one_hot = label_smoothing and label_smoothing > 0

  builders = _get_dataset_builders(params, strategy, one_hot)
  datasets = [builder.build() if builder else None for builder in builders]

  # Unpack datasets and builders based on train/val/test splits.
  train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
  train_dataset, validation_dataset = datasets

  train_epochs = params.train.epochs
  train_steps = params.train.steps or train_builder.num_steps
  validation_steps = params.evaluation.steps or validation_builder.num_steps

  initialize(params, train_builder)

  logging.info('Global batch size: %d', train_builder.global_batch_size)

  with strategy_scope:
    model_params = params.model.model_params.as_dict()
    model = get_models()[params.model.name](**model_params)
    learning_rate = optimizer_factory.build_learning_rate(
        params=params.model.learning_rate,
        batch_size=train_builder.global_batch_size,
        train_steps=train_steps)
    optimizer = optimizer_factory.build_optimizer(
        optimizer_name=params.model.optimizer.name,
        base_learning_rate=learning_rate,
        params=params.model.optimizer.as_dict())

    metrics_map = _get_metrics(one_hot)
    metrics = [metrics_map[metric] for metric in params.train.metrics]

    if one_hot:
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=params.model.loss.label_smoothing)
    else:
      loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(optimizer=optimizer, loss=loss_obj, metrics=metrics)

    initial_epoch = 0
    if params.train.resume_checkpoint:
      initial_epoch = resume_from_checkpoint(model=model,
                                             model_dir=params.model_dir,
                                             train_steps=train_steps)

  serialize_config(params=params, model_dir=params.model_dir)
  # TODO(dankondratyuk): callbacks significantly slow down training.
  callbacks = custom_callbacks.get_callbacks(
      model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
      include_tensorboard=params.train.callbacks.enable_tensorboard,
      time_history=params.train.callbacks.enable_time_history,
      track_lr=params.train.tensorboard.track_lr,
      write_model_weights=params.train.tensorboard.write_model_weights,
      initial_step=initial_epoch * train_steps,
      batch_size=train_builder.global_batch_size,
      log_steps=params.train.time_history.log_steps,
      model_dir=params.model_dir)

  if params.evaluation.skip_eval:
    validation_kwargs = {}
  else:
    validation_kwargs = {
        'validation_data': validation_dataset,
        'validation_steps': validation_steps,
        'validation_freq': params.evaluation.epochs_between_evals,
    }

  history = model.fit(
      train_dataset,
      epochs=train_epochs,
      steps_per_epoch=train_steps,
      initial_epoch=initial_epoch,
      callbacks=callbacks,
      **validation_kwargs)

  validation_output = None
  if not params.evaluation.skip_eval:
    validation_output = model.evaluate(
        validation_dataset, steps=validation_steps, verbose=2)

  # TODO(dankondratyuk): eval and save final test accuracy.
  stats = common.build_stats(history, validation_output, callbacks)
  return stats
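# Why `one_hot` is derived from label_smoothing above: Keras only exposes a
# `label_smoothing` argument on CategoricalCrossentropy, so smoothed training
# needs one-hot labels, while the unsmoothed path can keep integer labels with
# SparseCategoricalCrossentropy. The values below are illustrative only.
def _label_smoothing_demo():
  logits = tf.constant([[0.1, 0.2, 2.0, 0.3]])
  sparse_labels = tf.constant([2])
  one_hot_labels = tf.one_hot(sparse_labels, depth=4)

  smoothed = tf.keras.losses.CategoricalCrossentropy(
      from_logits=True, label_smoothing=0.1)
  plain = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  return (float(smoothed(one_hot_labels, logits)),
          float(plain(sparse_labels, logits)))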
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      train_epochs * per_epoch_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
  profiler_callback = None
  if flags_obj.profile_steps is not None:
    profiler_callback = keras_utils.get_profiler_callback(
        flags_obj.model_dir, flags_obj.profile_steps,
        flags_obj.enable_tensorboard, per_epoch_steps)

  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              per_epoch_steps,
                                              profiler_callback)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
  checkpoint_interval = (
      per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  resnet_controller = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      train_steps=per_epoch_steps * train_epochs,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      eval_steps=eval_steps,
      eval_interval=eval_interval)

  time_callback.on_train_begin()
  resnet_controller.train(evaluate=not flags_obj.skip_eval)
  time_callback.on_train_end()

  stats = build_stats(runnable, time_callback)
  return stats
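# Back-of-envelope check for the step counts logged above, assuming the
# standard ImageNet sizes exposed by imagenet_preprocessing.NUM_IMAGES and a
# batch size of 256; get_num_train_iterations() is expected to report values
# consistent with this floor division.
def _imagenet_step_count_demo(batch_size=256):
  per_epoch_steps = imagenet_preprocessing.NUM_IMAGES['train'] // batch_size
  eval_steps = imagenet_preprocessing.NUM_IMAGES['validation'] // batch_size
  # 1281167 // 256 = 5004 train steps per epoch, 50000 // 256 = 195 eval steps.
  return per_epoch_steps, eval_steps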
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    scale_loss=True,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    metric_fn=None,
    init_checkpoint=None,
    custom_callbacks=None,
    run_eagerly=False,
    sub_model_export_name=None,
    explicit_allreduce=False,
    pre_allreduce_callbacks=None,
    post_allreduce_callbacks=None):
  """Run BERT pretrain model training using low-level API.

  Args:
    _sentinel: Used to prevent positional parameters. Internal, do not use.
    strategy: Distribution strategy on which to run the low-level training
      loop.
    model_fn: Function that returns a tuple (model, sub_model). Caller of this
      function should add an optimizer to the `model` via calling
      `model.compile()` or manually setting the `model.optimizer` attribute.
      The second element of the returned tuple (sub_model) is an optional sub
      model to be used for loading the initial checkpoint -- if provided.
    loss_fn: Function with signature func(labels, logits) that returns a loss
      tensor.
    scale_loss: Whether to divide the raw loss by the number of replicas
      before gradient calculation.
    model_dir: Model directory used during training for restoring/saving model
      weights.
    train_input_fn: Function that returns a tf.data.Dataset used for training.
    steps_per_epoch: Number of steps to run per epoch. At the end of each
      epoch, a model checkpoint will be saved and evaluation will be conducted
      if an evaluation dataset is provided.
    steps_per_loop: Number of steps per graph-mode loop. In order to reduce
      communication in eager context, training logs are printed every
      steps_per_loop.
    epochs: Number of epochs to train.
    eval_input_fn: Function that returns the evaluation dataset. If None,
      evaluation is skipped.
    eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
      is not None.
    metric_fn: A metrics function that returns a Keras Metric object used to
      record evaluation results on the evaluation dataset, or on the training
      dataset after every epoch.
    init_checkpoint: Optional checkpoint to load into the `sub_model` returned
      by `model_fn`.
    custom_callbacks: A list of Keras Callback objects to run during training.
      More specifically, the `on_batch_begin()` and `on_batch_end()` methods
      are invoked during training.
    run_eagerly: Whether to run model training in pure eager execution. This
      should be disabled for TPUStrategy.
    sub_model_export_name: If not None, will export the `sub_model` returned
      by `model_fn` into checkpoint files. The name of an intermediate
      checkpoint file is {sub_model_export_name}_step_{step}.ckpt and the last
      checkpoint's name is {sub_model_export_name}.ckpt; if None, `sub_model`
      will not be exported as a checkpoint.
    explicit_allreduce: Whether to explicitly perform gradient allreduce,
      instead of relying on implicit allreduce in optimizer.apply_gradients().
      Defaults to False. For now, if training with FP16 mixed precision,
      explicit allreduce aggregates gradients in FP16 format. For TPU and GPU
      training using FP32, explicit allreduce aggregates gradients in FP32
      format.
    pre_allreduce_callbacks: A list of callback functions that take
      gradient-and-variable pairs as input, manipulate them, and return new
      gradient-and-variable pairs. The callback functions are invoked in list
      order, before gradients are allreduced. With mixed precision training,
      the pre_allreduce_callbacks are applied to scaled gradients. Default is
      no callbacks. Only used when explicit_allreduce=True.
    post_allreduce_callbacks: A list of callback functions that take
      gradient-and-variable pairs as input, manipulate them, and return new
      gradient-and-variable pairs. The callback functions are invoked in list
      order, right before gradients are applied to variables for updates.
      Default is no callbacks. Only used when explicit_allreduce=True.

  Returns:
    Trained model.

  Raises:
    ValueError: (1) When the model returned by `model_fn` does not have an
      optimizer attribute or when required parameters are set to None. (2)
      eval args are not specified correctly. (3) metric_fn must be a callable
      if specified. (4) sub_model_export_name is specified, but `sub_model`
      returned by `model_fn` is None.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if [arg for arg in required_arguments if arg is None]:
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
                     'parameters.')
  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        ' steps_per_epoch: %d, we will use steps_per_epoch as'
        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  assert tf.executing_eagerly()

  if run_eagerly:
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      raise ValueError(
          'TPUStrategy should not run eagerly as it heavily relies on graph'
          ' optimization for the distributed system.')

  if eval_input_fn and (eval_steps is None or metric_fn is None):
    raise ValueError(
        '`eval_steps` and `metric_fn` are required when `eval_input_fn` '
        'is not none.')
  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  total_training_steps = steps_per_epoch * epochs
  train_iterator = _get_input_iterator(train_input_fn, strategy)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
    # model and optimizer should be created in scope.
    model, sub_model = model_fn()
    if not hasattr(model, 'optimizer'):
      raise ValueError('User should set optimizer attribute to model '
                       'inside `model_fn`.')
    if sub_model_export_name and sub_model is None:
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    optimizer = model.optimizer

    if init_checkpoint:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'initial checkpoint for core model.', init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
      logging.info('Loading from checkpoint file completed')

    train_loss_metric = tf.keras.metrics.Mean('training_loss',
                                              dtype=tf.float32)
    eval_metrics = [metric_fn()] if metric_fn else []
    # If evaluation is required, make a copy of the metric as it will be used
    # by both train and evaluation.
    train_metrics = [
        metric.__class__.from_config(metric.get_config())
        for metric in eval_metrics
    ]

    # Create summary writers.
    if _should_export_summary(strategy):
      summary_dir = os.path.join(model_dir, 'summaries')
    else:
      # In multi-worker training we need every worker to write summary,
      # because variables can trigger synchronization on read and
      # synchronization needs all workers to participate.
      summary_dir = tempfile.mkdtemp()
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, 'eval'))
    if steps_per_loop >= _MIN_SUMMARY_STEPS:
      # Only writes summary when the stats are collected sufficiently over
      # enough steps.
      train_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'train'))
    else:
      train_summary_writer = None

    # Collects training variables.
    training_vars = model.trainable_variables

    def _replicated_step(inputs):
      """Replicated training step."""

      inputs, labels = inputs
      with tf.GradientTape() as tape:
        model_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
        # Raw loss is used for reporting in metrics/logs.
        raw_loss = loss
        if scale_loss:
          # Scales down the loss for gradients to be invariant from replicas.
          loss = loss / strategy.num_replicas_in_sync

      if explicit_allreduce:
        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     training_vars,
                                                     pre_allreduce_callbacks,
                                                     post_allreduce_callbacks)
      else:
        if isinstance(
            optimizer,
            tf.keras.mixed_precision.experimental.LossScaleOptimizer):
          with tape:
            scaled_loss = optimizer.get_scaled_loss(loss)
          scaled_grads = tape.gradient(scaled_loss, training_vars)
          grads = optimizer.get_unscaled_gradients(scaled_grads)
        else:
          grads = tape.gradient(loss, training_vars)
        optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(raw_loss)
      for metric in train_metrics:
        metric.update_state(labels, model_outputs)

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: a tf.int32 integer tensor specifying the number of steps to run
          inside the host training loop.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError('steps should be a Tensor. Python objects may cause '
                         'retracing.')

      for _ in tf.range(steps):
        strategy.run(_replicated_step, args=(next(iterator),))

    def train_single_step(iterator):
      """Performs a distributed training step.

      Args:
        iterator: the distributed iterator of training datasets.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      strategy.run(_replicated_step, args=(next(iterator),))

    def test_step(iterator):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""

        inputs, labels = inputs
        model_outputs = model(inputs, training=False)
        for metric in eval_metrics:
          metric.update_state(labels, model_outputs)

      strategy.run(_test_step_fn, args=(next(iterator),))

    if not run_eagerly:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    def _run_evaluation(current_training_step, test_iterator):
      """Runs validation steps and aggregates metrics."""
      for _ in range(eval_steps):
        test_step(test_iterator)

      with eval_summary_writer.as_default():
        for metric in eval_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          logging.info('Step: [%d] Validation %s = %f', current_training_step,
                       metric.name, metric_value)
          tf.summary.scalar(
              metric.name, metric_value, step=current_training_step)
        eval_summary_writer.flush()

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_begin(batch)

    def _run_callbacks_on_batch_end(batch, logs):
      """Runs custom callbacks at the end of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_end(batch, logs)

    # Training loop starts here.
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    sub_model_checkpoint = tf.train.Checkpoint(
        model=sub_model) if sub_model_export_name else None

    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'checkpoint', latest_checkpoint_file)
      checkpoint.restore(latest_checkpoint_file)
      logging.info('Loading from checkpoint file completed')

    current_step = optimizer.iterations.numpy()
    checkpoint_name = 'ctl_step_{step}.ckpt'

    while current_step < total_training_steps:
      # Training loss/metric take the average over steps inside the micro
      # training loop. We reset their values before each round.
      train_loss_metric.reset_states()
      for metric in train_metrics + model.metrics:
        metric.reset_states()

      _run_callbacks_on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

      if tf.test.is_built_with_cuda():
        # TODO(zongweiz): merge with train_steps once tf.while_loop
        # GPU performance bugs are fixed.
        for _ in range(steps):
          train_single_step(train_iterator)
      else:
        # Converts steps to a Tensor to avoid tf.function retracing.
        train_steps(train_iterator,
                    tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      current_step += steps
      _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})

      # Updates training logging.
      training_status = 'Train Step: %d/%d / loss = %s' % (
          current_step, total_training_steps, train_loss)

      if train_summary_writer:
        with train_summary_writer.as_default():
          tf.summary.scalar(
              train_loss_metric.name, train_loss, step=current_step)
          for metric in train_metrics + model.metrics:
            metric_value = _float_metric_value(metric)
            training_status += ' %s = %f' % (metric.name, metric_value)
            tf.summary.scalar(metric.name, metric_value, step=current_step)
          train_summary_writer.flush()
      logging.info(training_status)

      # Saves model checkpoints and runs validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
        # step of training.
        if current_step < total_training_steps:
          _save_checkpoint(strategy, checkpoint, model_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
                strategy, sub_model_checkpoint, model_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
        if eval_input_fn:
          logging.info('Running evaluation after step: %s.', current_step)
          _run_evaluation(current_step,
                          _get_input_iterator(eval_input_fn, strategy))
          # Re-initialize evaluation metric.
          for metric in eval_metrics + model.metrics:
            metric.reset_states()

    _save_checkpoint(strategy, checkpoint, model_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
      _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
                       '%s.ckpt' % sub_model_export_name)

    if eval_input_fn:
      logging.info('Running final evaluation after training is complete.')
      _run_evaluation(current_step,
                      _get_input_iterator(eval_input_fn, strategy))

    training_summary = {
        'total_training_steps': total_training_steps,
        'train_loss': _float_metric_value(train_loss_metric),
    }
    if eval_metrics:
      # TODO(hongkuny): Cleans up summary reporting in text.
      training_summary['last_train_metrics'] = _float_metric_value(
          train_metrics[0])
      training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])

    write_txt_summary(training_summary, summary_dir)

    if not _should_export_summary(strategy):
      tf.io.gfile.rmtree(summary_dir)

    return model
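# A minimal, hypothetical caller sketch for run_customized_training_loop().
# The toy model_fn / train_input_fn below are illustrative stand-ins for the
# real BERT builders; only the loop's own contract (model.optimizer set inside
# model_fn, loss_fn(labels, logits), the required arguments) is taken from the
# docstring above.
def _customized_training_loop_demo(model_dir='/tmp/ctl_demo'):

  def _toy_model_fn():
    # input_shape builds the model immediately so trainable_variables exist
    # when the loop collects them.
    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(8,))])
    model.optimizer = tf.keras.optimizers.SGD(0.1)  # the loop reads model.optimizer
    return model, None  # no sub_model, so init_checkpoint/export are unused

  def _toy_train_input_fn(ctx=None):
    del ctx  # input context supplied by the distribution strategy
    features = tf.random.uniform([64, 8])
    labels = tf.random.uniform([64], maxval=2, dtype=tf.int32)
    return tf.data.Dataset.from_tensor_slices(
        (features, labels)).repeat().batch(8)

  return run_customized_training_loop(
      strategy=tf.distribute.get_strategy(),  # default (single-device) strategy
      model_fn=_toy_model_fn,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      model_dir=model_dir,
      train_input_fn=_toy_train_input_fn,
      steps_per_epoch=10,
      steps_per_loop=5,
      epochs=1)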