def initialize(params: base_configs.ExperimentConfig,
               dataset_builder: dataset_factory.DatasetBuilder):
    """Apply backend-wide settings before any model or dataset is built.

    Every setting is read from ``params.runtime`` except the mixed precision
    policy, which is taken from ``dataset_builder.dtype``.

    Args:
        params: Experiment configuration; only its ``runtime`` section is
            read here.
        dataset_builder: Dataset builder whose ``dtype`` is passed to the
            mixed precision policy.
    """
    runtime = params.runtime

    keras_utils.set_session_config(
        enable_eager=runtime.run_eagerly,
        enable_xla=runtime.enable_xla)

    if runtime.gpu_threads_enabled:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=runtime.per_gpu_thread_count,
            gpu_thread_mode=runtime.gpu_thread_mode,
            num_gpus=runtime.num_gpus,
            datasets_num_private_threads=runtime.dataset_num_private_threads)

    performance.set_mixed_precision_policy(dataset_builder.dtype)

    # Use channels_first only when TensorFlow reports at least one GPU.
    gpu_devices = tf.config.list_physical_devices('GPU')
    tf.keras.backend.set_image_data_format(
        'channels_first' if gpu_devices else 'channels_last')

    distribution_utils.configure_cluster(
        runtime.worker_hosts, runtime.task_index)

    if runtime.run_eagerly:
        # Enable eager execution to allow step-by-step debugging.
        tf.config.experimental_run_functions_eagerly(True)
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If `flags_obj.optimizer` is neither 'resnet50_default'
            nor 'mobilenet_default', or if no model can be selected from
            `flags_obj.use_trivial_model` / `flags_obj.model`.
            NOTE(review): the previous docstring also listed fp16 as a
            ValueError case; that check is not visible in this function.
        NotImplementedError: If pruning is requested with a dtype other than
            float32, or with a method other than 'polynomial_decay'.

    Returns:
        Dictionary of training and eval stats.
    """
    # Function-scope import so this block does not rely on the file header.
    import contextlib

    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance.
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=(
                flags_obj.datasets_num_private_threads))
    common.set_cudnn_batchnorm_mode()

    # Resolve the dtype once and reuse it for the policy, the input pipeline
    # and the pruning/export checks below.
    dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(
        dtype,
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                         flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # Current resnet_model.resnet50 input format is always channel-last.
    # We use the keras_application mobilenet model, whose input format depends
    # on the keras backend image data format.
    # This use_keras_image_data_format flag indicates whether the image
    # preprocessor output format should be the same as the keras backend image
    # data format or just channel-last format.
    use_keras_image_data_format = (flags_obj.model == 'mobilenet')

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
        multipliers=[p[0] for p in common.LR_SCHEDULE],
        compute_lr_on_cpu=True)
    steps_per_epoch = (
        imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
        elif flags_obj.optimizer == 'mobilenet_default':
            initial_learning_rate = (
                flags_obj.initial_learning_rate_per_sample
                * flags_obj.batch_size)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=(
                        steps_per_epoch * flags_obj.num_epochs_per_decay),
                    decay_rate=flags_obj.lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the first later use of `optimizer`.
            raise ValueError(
                'Unsupported optimizer: {!r}; expected "resnet50_default" '
                'or "mobilenet_default".'.format(flags_obj.optimizer))

        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite",
            # dtype as determined by flags_core.get_tf_dtype(flags_obj) would
            # be 'float32' which will ensure
            # tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do
            # not double up.
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer))

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = test_utils.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF
            # version support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=imagenet_preprocessing.NUM_CLASSES,
                layers=tf.keras.layers)
        else:
            # Same reasoning as for the optimizer: `model` would otherwise be
            # unbound when it is first used below.
            raise ValueError(
                'Unsupported model: {!r}; expected "resnet50_v1.5" or '
                '"mobilenet".'.format(flags_obj.model))

        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')

        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    # Callbacks are built with the full-epoch step count, before the
    # train_steps override below is applied.
    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (
        imagenet_preprocessing.NUM_IMAGES['validation']
        // flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack lets the device scope be entered conditionally while still
    # guaranteeing a protocol-correct __exit__, including when fit/evaluate/
    # save raise.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(
            train_input_dataset,
            epochs=train_epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
            validation_steps=num_eval_steps,
            validation_data=validation_data,
            validation_freq=flags_obj.epochs_between_evals,
            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(
                eval_input_dataset, steps=num_eval_steps, verbose=2)

        if flags_obj.pruning_method:
            model = tfmot.sparsity.keras.strip_pruning(model)
        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    'Keras model.save does not support bfloat16 dtype.')
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run(flags_obj):
    """Train and evaluate ResNet on ImageNet with a custom training loop.

    Sets up the session, the precision policy (or the Habana bf16 conversion
    pass), the image data format and a cluster spec built from
    `comm_rank()` / `comm_size()`, then runs training through
    `controller.Controller`.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dictionary of training and eval stats, as returned by `build_stats`.

    NOTE(review): the previous docstring listed ValueError for fp16; no such
    check is visible in this function - confirm in the helpers it calls.
    """
    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla,
        enable_scoped_allocator=flags_obj.enable_scoped_allocator)

    # Enable habana bf16 conversion pass only if native keras mixed precision
    # is disabled. The `== False` comparison is kept as-is: a `not` test would
    # also match None.
    use_habana_bf16 = (
        flags.FLAGS.dtype == 'bf16'
        and flags.FLAGS.use_keras_mixed_precision == False)  # noqa: E712
    if not use_habana_bf16:
        performance.set_mixed_precision_policy(
            flags_core.get_tf_dtype(flags_obj))
    else:
        performance.set_mixed_precision_policy(tf.float32)
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path

    # Defaults only: values already present in the environment win.
    for env_name in ("TF_DISABLE_MKL", "TF_ALLOW_CONTROL_EDGES_IN_HABANA_OPS"):
        os.environ.setdefault(env_name, "1")

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    image_format = flags_obj.data_format
    if image_format is None:
        if tf.test.is_built_with_cuda():
            image_format = 'channels_first'
        else:
            image_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_format)

    batch_size = adjust_batch_size(flags_obj.batch_size)

    # Under Horovod every worker writes into its own sub-directory.
    model_dir = flags_obj.model_dir
    if horovod_enabled():
        model_dir = os.path.join(model_dir, "worker_" + str(hvd.rank()))

    # One "host:port" entry per worker, grouped per host address; ports count
    # up from the base port within each host.
    host_addresses = os.environ.get("MULTI_HLS_IPS", "127.0.0.1").split(",")
    base_port = 2410
    task_index = comm_rank()
    world_size = comm_size()
    workers_per_host = world_size // len(host_addresses)
    per_host_specs = []
    for host in host_addresses:
        endpoints = [host + ':' + str(base_port + local_rank)
                     for local_rank in range(workers_per_host)]
        per_host_specs.append(",".join(endpoints))
    worker_hosts = ",".join(per_host_specs)

    # Configures cluster spec for distribution strategy.
    distribution_utils.configure_cluster(worker_hosts, task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    train_writer = None
    eval_writer = None
    if flags_obj.enable_tensorboard:
        train_writer = tf.summary.create_file_writer(model_dir)
        eval_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'eval'))
        hparams = flags_obj.flag_values_dict()
        hparams.setdefault('precision', flags_obj.dtype)
        write_hparams_v2(train_writer, hparams)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        batch_size,
        flags_obj.log_steps,
        summary_writer=train_writer,
        batch_size_per_node=flags_obj.batch_size)

    if flags_obj.profile_steps is None:
        profiler_callback = None
    else:
        profiler_callback = keras_utils.get_profiler_callback(
            model_dir,
            flags_obj.profile_steps,
            flags_obj.enable_tensorboard,
            per_epoch_steps)

    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(
            flags_obj, time_callback, train_steps, per_epoch_steps,
            profiler_callback)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps

    # Checkpoint once per epoch, and only when export is enabled.
    checkpoint_interval = None
    if flags_obj.enable_checkpoint_and_export:
        checkpoint_interval = per_epoch_steps

    summary_interval = None
    if flags_obj.enable_tensorboard:
        summary_interval = flags_obj.log_steps

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        train_summary_writer=train_writer,
        eval_summary_writer=eval_writer)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    return build_stats(runnable, time_callback)
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    In addition to training, this variant emits MLPerf log events (submission
    metadata, INIT/RUN/BLOCK markers and the final RUN_STOP status) through
    the logger returned by `get_mllog_mlloger`.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dictionary of training and eval stats, as returned by `build_stats`.

    NOTE(review): the previous docstring listed ValueError for fp16; no such
    check is visible in this function - confirm in the helpers it calls.
    """
    tf.get_logger().propagate = False

    # LOG_DIR is optional; None is passed through when it is not set.
    output_dir = os.environ.get("LOG_DIR")
    mlperf_mlloger, mlperf_mllog = get_mllog_mlloger(output_dir)
    mlperf_mlloger.event(key=mlperf_mllog.constants.CACHE_CLEAR, value=True)
    mlperf_mlloger.start(key=mlperf_mllog.constants.INIT_START, value=None)
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_BENCHMARK,
                         value=mlperf_mllog.constants.RESNET)
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_ORG,
                         value='Habana')
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_DIVISION,
                         value='closed')
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_PLATFORM,
                         value='gaudi-{}'.format(flags_obj.num_gpus))
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_STATUS,
                         value='onprem')

    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Under Horovod every worker writes into its own sub-directory.
    if horovod_enabled():
        model_dir = os.path.join(flags_obj.model_dir,
                                 "worker_" + str(hvd.rank()))
    else:
        model_dir = flags_obj.model_dir

    global_batch_size = get_global_batch_size(flags_obj.batch_size)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    mlperf_mlloger.event(key=mlperf_mllog.constants.GLOBAL_BATCH_SIZE,
                         value=global_batch_size)
    mlperf_mlloger.event(key=mlperf_mllog.constants.TRAIN_SAMPLES,
                         value=imagenet_preprocessing.NUM_IMAGES['train'])
    mlperf_mlloger.event(key=mlperf_mllog.constants.EVAL_SAMPLES,
                         value=imagenet_preprocessing.NUM_IMAGES['validation'])
    group_batch_norm = 1
    mlperf_mlloger.event(key=mlperf_mllog.constants.MODEL_BN_SPAN,
                         value=flags_obj.batch_size * group_batch_norm)

    train_writer, eval_writer = None, None
    if flags_obj.enable_tensorboard:
        train_writer = tf.summary.create_file_writer(model_dir)
        eval_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'eval'))
        hparams = flags_obj.flag_values_dict()
        write_hparams_v2(train_writer, hparams)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    # Total step budget; used for logging, the runnable and the controller.
    train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        global_batch_size,
        flags_obj.log_steps,
        summary_writer=train_writer,
        batch_size_per_node=flags_obj.batch_size)

    profiler_callback = None
    if flags_obj.profile_steps is not None:
        profiler_callback = keras_utils.get_profiler_callback(
            model_dir,
            flags_obj.profile_steps,
            flags_obj.enable_tensorboard,
            per_epoch_steps)

    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(
            flags_obj, time_callback, train_steps, per_epoch_steps,
            profiler_callback, mlperf_mlloger, mlperf_mllog)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    # A non-zero offset is shifted back by one eval interval before being
    # handed to the controller.
    eval_offset = flags_obj.eval_offset_epochs * per_epoch_steps
    if eval_offset != 0:
        eval_offset -= eval_interval
    checkpoint_interval = (
        per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = (
        per_epoch_steps if flags_obj.enable_tensorboard else None)

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    device_warmup_steps = (
        flags_obj.device_warmup_steps
        if flags_obj.enable_device_warmup else 0)
    if flags_obj.enable_device_warmup:
        logging.info('Warmup for %d steps.', device_warmup_steps)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        runnable.warmup,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        eval_offset=eval_offset,
        device_warmup_steps=device_warmup_steps,
        train_summary_writer=train_writer,
        eval_summary_writer=eval_writer)

    if flags_obj.enable_device_warmup:
        resnet_controller.warmup()

    mlperf_mlloger.end(key=mlperf_mllog.constants.INIT_STOP)

    # Horovod is optional in this function (see the model_dir selection
    # above), so only issue the collective when it is actually enabled.
    # NOTE(review): presumably this broadcast acts as a cross-worker barrier
    # before timing starts - confirm.
    if horovod_enabled():
        hvd.broadcast(0, 0)
    time_callback.on_train_begin()
    mlperf_mlloger.start(key=mlperf_mllog.constants.RUN_START)
    mlperf_mlloger.start(
        key=mlperf_mllog.constants.BLOCK_START,
        value=None,
        metadata={
            'first_epoch_num': 1,
            'epoch_count': (flags_obj.eval_offset_epochs
                            if flags_obj.eval_offset_epochs > 0
                            else flags_obj.epochs_between_evals),
        })
    resnet_controller.train(evaluate=not flags_obj.skip_eval)

    # RUN_STOP is only logged when evaluation ran; its status reflects
    # whether the last eval reached the target accuracy.
    if not flags_obj.skip_eval:
        eval_accuracy = resnet_controller.last_eval_output['test_accuracy']
        run_status = ('success'
                      if eval_accuracy >= flags_obj.target_accuracy
                      else 'fail')
        mlperf_mlloger.end(key=mlperf_mllog.constants.RUN_STOP,
                           value=None,
                           metadata={'status': run_status})
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
def run(flags_obj):
    """Train and evaluate ResNet on ImageNet with a custom training loop.

    Configures the session, precision policy and image data format, builds
    the distribution strategy and the `ResnetRunnable`, then hands control
    to `controller.Controller` for training (and evaluation unless
    `flags_obj.skip_eval` is set).

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dictionary of training and eval stats, as returned by `build_stats`.

    NOTE(review): the previous docstring listed ValueError for fp16; no such
    check is visible in this function - confirm in the helpers it calls.
    """
    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    image_format = flags_obj.data_format
    if image_format is None:
        if tf.test.is_built_with_cuda():
            image_format = 'channels_first'
        else:
            image_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    total_train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        total_train_steps, eval_steps)

    # TimeHistory only gets a log directory when TensorBoard is enabled.
    if flags_obj.enable_tensorboard:
        time_logdir = flags_obj.model_dir
    else:
        time_logdir = None
    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size, flags_obj.log_steps, logdir=time_logdir)

    if flags_obj.profile_steps is None:
        profiler_callback = None
    else:
        profiler_callback = keras_utils.get_profiler_callback(
            flags_obj.model_dir,
            flags_obj.profile_steps,
            flags_obj.enable_tensorboard,
            per_epoch_steps)

    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(
            flags_obj, time_callback, per_epoch_steps, profiler_callback)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps

    # Checkpoints and summaries are both written once per epoch, each only
    # when its feature flag is on.
    checkpoint_interval = None
    if flags_obj.enable_checkpoint_and_export:
        checkpoint_interval = per_epoch_steps

    summary_interval = None
    if flags_obj.enable_tensorboard:
        summary_interval = per_epoch_steps

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=total_train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    return build_stats(runnable, time_callback)