def initialize(params: base_configs.ExperimentConfig,
               dataset_builder: dataset_factory.DatasetBuilder):
  """Configures the TensorFlow/Keras backend ahead of a run.

  Applies, in order: the session config (eager / XLA), optional GPU thread
  settings, the mixed precision policy, the Keras image data format and the
  cluster configuration; finally switches `tf.function`s to eager execution
  when `params.runtime.run_eagerly` is set.

  Args:
    params: Experiment config. `params.runtime` supplies the runtime knobs
      read here; the whole object is also passed to `get_loss_scale`.
    dataset_builder: Dataset builder whose `dtype` is forwarded to
      `performance.set_mixed_precision_policy`.
  """
  runtime = params.runtime

  keras_utils.set_session_config(
      enable_eager=runtime.run_eagerly, enable_xla=runtime.enable_xla)

  if runtime.gpu_threads_enabled:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=runtime.per_gpu_thread_count,
        gpu_thread_mode=runtime.gpu_thread_mode,
        num_gpus=runtime.num_gpus,
        datasets_num_private_threads=runtime.dataset_num_private_threads)

  performance.set_mixed_precision_policy(dataset_builder.dtype,
                                         get_loss_scale(params))

  # Channels-first is selected whenever at least one physical GPU is listed.
  gpus = tf.config.list_physical_devices('GPU')
  tf.keras.backend.set_image_data_format(
      'channels_first' if gpus else 'channels_last')

  distribution_utils.configure_cluster(runtime.worker_hosts,
                                       runtime.task_index)

  if runtime.run_eagerly:
    # Run tf.functions eagerly so the code can be stepped through in a
    # debugger.
    tf.config.experimental_run_functions_eagerly(True)
def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None,
             custom_metrics=None):
  """Run BERT training.

  Args:
    strategy: Distribution strategy; must be truthy.
    input_meta_data: Mapping providing at least 'train_data_size' and
      'eval_data_size'.
    model_config: Model configuration forwarded to `run_bert_classifier`.
    train_input_fn: Training input function forwarded to
      `run_bert_classifier`.
    eval_input_fn: Evaluation input function forwarded to
      `run_bert_classifier`.
    init_checkpoint: Initial checkpoint; falls back to
      `FLAGS.init_checkpoint` when falsy.
    custom_callbacks: Optional iterable of callbacks. It is copied, so the
      caller's object is never modified.
    custom_metrics: Optional metrics forwarded to `run_bert_classifier`.

  Returns:
    The trained model returned by `run_bert_classifier`.

  Raises:
    ValueError: If `strategy` is not specified.
  """
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype(),
                                         use_experimental_api=False)

  # Each "epoch" handed to the classifier is one evaluation interval, so the
  # epoch count is scaled up and the per-epoch data size scaled down.
  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
  train_data_size = (
      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
  if FLAGS.train_data_size:
    train_data_size = min(train_data_size, FLAGS.train_data_size)
    logging.info('Updated train_data_size: %s', train_data_size)
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  # Warm up over the first 10% of all training steps.
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  # Work on a private copy: appending TimeHistory below must not leak into
  # the list object owned by the caller.
  custom_callbacks = list(custom_callbacks) if custom_callbacks else []
  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model
def run(callbacks=None):
  """Builds the parameters and input functions, then runs the executor.

  Parameters start from `config_factory.config_generator(FLAGS.model)` and
  are overridden strictly by `FLAGS.config_file` and `FLAGS.params_override`,
  then non-strictly by the strategy/model_dir flags, before being validated
  and locked.

  Args:
    callbacks: Optional callbacks forwarded to `run_executor`.

  Returns:
    Whatever `run_executor` returns.

  Raises:
    ValueError: If neither a training nor an eval file pattern is available.
  """
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  params = config_factory.config_generator(FLAGS.model)
  for strict_override in (FLAGS.config_file, FLAGS.params_override):
    params = params_dict.override_params_dict(
        params, strict_override, is_strict=True)

  flag_overrides = {
      'strategy_type': FLAGS.strategy_type,
      'model_dir': FLAGS.model_dir,
      'strategy_config': executor.strategy_flags_dict(),
  }
  params.override(flag_overrides, is_strict=False)
  params.validate()
  params.lock()

  logging.info(
      'Model Parameters: {}'.format(pprint.pformat(params.as_dict())))

  # Command-line patterns take precedence over the ones in the config.
  training_file_pattern = (
      FLAGS.training_file_pattern or params.train.train_file_pattern)
  eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not (training_file_pattern or eval_file_pattern):
    raise ValueError(
        'Must provide at least one of training_file_pattern and '
        'eval_file_pattern.')

  train_input_fn = None
  if training_file_pattern:
    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size)

  eval_input_fn = None
  if eval_file_pattern:
    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples)

  return run_executor(
      params,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      callbacks=callbacks)
def run(flags_obj):
  """Run Shakespeare training and predict.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary with status from the run. It holds 'history' and 'callbacks'
    entries when `flags_obj.train` is set and is empty otherwise.

  Raises:
    ValueError: If no training data path is given, or if a prediction is
      requested without `model_dir`.
  """
  if not flags_obj.training_data:
    raise ValueError(
        'Must set the path to a training data file. e.g download the following '
        'https://storage.googleapis.com/download.tensorflow.org/data/'
        'shakespeare.txt')

  if flags_obj.dtype == 'fp16':
    mixed_precision = tf.keras.mixed_precision.experimental
    fp16_policy = mixed_precision.Policy(
        'mixed_float16',
        loss_scale=flags_core.get_loss_scale(
            flags_obj, default_for_fp16='dynamic'))
    mixed_precision.set_policy(fp16_policy)

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  dataset, idx2char, char2idx = get_dataset(
      flags_obj.training_data, batch_size=flags_obj.batch_size)

  stats = {}
  if flags_obj.train:
    vocab_size = len(idx2char)
    history, callbacks = train_model(
        flags_obj, dataset, vocab_size, strategy,
        checkpoint_dir=flags_obj.model_dir)
    stats['history'] = history.history
    stats['callbacks'] = callbacks

  if flags_obj.predict_context:
    if not flags_obj.model_dir:
      raise ValueError('Must set model_dir to get predictions.')
    prediction = make_prediction(
        flags_obj.model_dir, flags_obj.predict_length,
        flags_obj.predict_context, idx2char, char2idx)
    print(prediction)

  return stats
def initialize(params: base_configs.ExperimentConfig,
               dataset_builder: dataset_factory.DatasetBuilder):
  """Configures the TensorFlow/Keras backend ahead of a run.

  Sets the session config (XLA), the mixed precision policy and the Keras
  image data format, optionally enables eager execution of functions, and
  applies the GPU-only settings when a physical GPU is listed.

  Args:
    params: Experiment config; only `params.runtime` is read.
    dataset_builder: Dataset builder whose `dtype` is forwarded to
      `performance.set_mixed_precision_policy`.
  """
  runtime = params.runtime

  keras_utils.set_session_config(enable_xla=runtime.enable_xla)
  performance.set_mixed_precision_policy(dataset_builder.dtype)

  image_format = ('channels_first'
                  if tf.config.list_physical_devices('GPU')
                  else 'channels_last')
  tf.keras.backend.set_image_data_format(image_format)

  if runtime.run_eagerly:
    # Run tf.functions eagerly so the code can be stepped through in a
    # debugger.
    tf.config.experimental_run_functions_eagerly(True)

  # Everything below only applies when a GPU is present.
  if not tf.config.list_physical_devices('GPU'):
    return

  if runtime.gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=runtime.per_gpu_thread_count,
        gpu_thread_mode=runtime.gpu_thread_mode,
        num_gpus=runtime.num_gpus,
        datasets_num_private_threads=runtime.dataset_num_private_threads)

  if runtime.batchnorm_spatial_persistent:
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
def run_ncf(_):
  """Run NCF training and eval with Keras.

  As visible here, the function configures the session, seed, mixed
  precision policy and distribution strategy, then stores the strategy in
  the parsed parameters.

  Args:
    _: Unused positional argument.
  """
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  seed = FLAGS.seed
  if seed is not None:
    print("Setting tf seed")
    tf.random.set_seed(seed)

  model_helpers.apply_clean(FLAGS)

  use_keras_fp16 = (
      FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras")
  if use_keras_fp16:
    mixed_precision = tf.keras.mixed_precision.experimental
    mixed_precision.set_policy(
        mixed_precision.Policy(
            "mixed_float16",
            loss_scale=flags_core.get_loss_scale(
                FLAGS, default_for_fp16="dynamic")))

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)

  params = ncf_common.parse_flags(FLAGS)
  params["distribute_strategy"] = strategy
def run_ncf(_):
  """Run NCF training and eval with Keras.

  Configures the session, seed, mixed precision policy and distribution
  strategy, builds the input pipeline (generated online by a producer thread
  or read from pre-built datasets), then trains and evaluates either with a
  custom training loop or with `Model.fit` / `Model.evaluate`.

  Args:
    _: Unused positional argument.

  Returns:
    The value returned by `build_stats`, or None when a TPU strategy is
    requested without the custom training loop.
  """
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  if FLAGS.seed is not None:
    print("Setting tf seed")
    tf.random.set_seed(FLAGS.seed)

  model_helpers.apply_clean(FLAGS)

  if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)

  params = ncf_common.parse_flags(FLAGS)
  params["distribute_strategy"] = strategy
  params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")

  if params["use_tpu"] and not params["keras_use_ctl"]:
    logging.error(
        "Custom training loop must be used when using TPUStrategy.")
    # NOTE(review): this path only logs and returns None; callers expecting
    # a stats object must handle that.
    return

  batch_size = params["batch_size"]
  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
  callbacks = [time_callback]

  producer, input_meta_data = None, None
  # Without a pre-built training dataset the input is generated online.
  generate_input_online = params["train_dataset_path"] is None

  if generate_input_online:
    # Start data producing thread.
    num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
    producer.start()
    per_epoch_callback = IncrementEpochCallback(producer)
    callbacks.append(per_epoch_callback)
  else:
    # Pre-built datasets need both the eval dataset and the metadata file.
    assert params["eval_dataset_path"] and params["input_meta_data_path"]
    with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
      input_meta_data = json.loads(reader.read().decode("utf-8"))
      num_users = input_meta_data["num_users"]
      num_items = input_meta_data["num_items"]

  params["num_users"], params["num_items"] = num_users, num_items

  if FLAGS.early_stopping:
    early_stopping_callback = CustomEarlyStopping(
        "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
    callbacks.append(early_stopping_callback)

  (train_input_dataset, eval_input_dataset,
   num_train_steps, num_eval_steps) = \
    (ncf_input_pipeline.create_ncf_input_data(
        params, producer, input_meta_data, strategy))
  # With online generation the step count per epoch is left to Keras (None).
  steps_per_epoch = None if generate_input_online else num_train_steps

  with distribute_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])
    if FLAGS.fp16_implementation == "graph_rewrite":
      optimizer = \
        tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
            optimizer,
            loss_scale=flags_core.get_loss_scale(FLAGS,
                                                 default_for_fp16="dynamic"))
    elif FLAGS.dtype == "fp16":
      loss_scale = flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic")
      # Note Model.compile automatically wraps the optimizer with a
      # LossScaleOptimizer using dynamic loss scaling. We explicitly wrap it
      # here for the case where a custom training loop or fixed loss scale is
      # used.
      if loss_scale == "dynamic":
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            optimizer)
      else:
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            optimizer, dynamic=False, initial_scale=loss_scale)

    if params["keras_use_ctl"]:
      train_loss, eval_results = run_ncf_custom_training(
          params,
          strategy,
          keras_model,
          optimizer,
          callbacks,
          train_input_dataset,
          eval_input_dataset,
          num_train_steps,
          num_eval_steps,
          generate_input_online=generate_input_online)
    else:
      keras_model.compile(optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)

      if not FLAGS.ml_perf:
        # Create Tensorboard summary and checkpoint callbacks.
        summary_dir = os.path.join(FLAGS.model_dir, "summaries")
        summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
        checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, save_weights_only=True)

        callbacks += [summary_callback, checkpoint_callback]

      history = keras_model.fit(
          train_input_dataset,
          epochs=FLAGS.train_epochs,
          steps_per_epoch=steps_per_epoch,
          callbacks=callbacks,
          validation_data=eval_input_dataset,
          validation_steps=num_eval_steps,
          verbose=2)

      logging.info("Training done. Start evaluating")

      eval_loss_and_metrics = keras_model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2)

      logging.info("Keras evaluation is done.")

      # Keras evaluate() API returns scalar loss and metric values from
      # evaluation as a list. Here, the returned list would contain
      # [evaluation loss, hr sum, hr count].
      eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]

      # Format evaluation result into [eval loss, eval hit accuracy].
      eval_results = [eval_loss_and_metrics[0], eval_hit_rate]

      # NOTE(review): `train_loss` is only assigned when the history is
      # non-empty; otherwise the `build_stats` call below would hit an
      # unbound local. Confirm whether an empty history can occur here.
      if history and history.history:
        train_history = history.history
        train_loss = train_history["loss"][-1]

  stats = build_stats(train_loss, eval_results, time_callback)
  return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    NotImplementedError: If pruning is requested with a dtype other than
      tf.float32, or with a pruning method other than 'polynomial_decay'.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use the keras_application mobilenet model, whose input format depends
  # on the Keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should be the same as the Keras backend image
  # data format or just the channel-last format.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # TODO(b/135607227): Add device scope automatically in Keras training loop
  # when not using distribution strategy.
  no_dist_strat_device = None
  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  # The device scope is entered by hand (it is conditional), so it must be
  # released in a `finally` to avoid leaking it when training, evaluation or
  # export raises.
  try:
    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2)

    if flags_obj.pruning_method:
      model = tfmot.sparsity.keras.strip_pruning(model)
    if flags_obj.enable_checkpoint_and_export:
      if dtype == tf.bfloat16:
        logging.warning('Keras model.save does not support bfloat16 dtype.')
      else:
        # Keras model.save assumes a float32 input signature.
        export_path = os.path.join(flags_obj.model_dir, 'saved_model')
        model.save(export_path, include_optimizer=False)
  finally:
    if no_dist_strat_device is not None:
      # The context-manager protocol passes (exc_type, exc_value, traceback)
      # to __exit__; supply them explicitly instead of calling it bare.
      no_dist_strat_device.__exit__(None, None, None)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def train(self):
  """Trains the model in rounds of `steps_between_evals` steps.

  After every round the BLEU evaluation is run when both
  `flags_obj.bleu_source` and `flags_obj.bleu_ref` are set.

  Returns:
    The stats built by `misc.build_stats` from the last round's history,
    extended with the latest BLEU scores and their per-round history when
    both scores are truthy.
  """
  params = self.params
  flags_obj = self.flags_obj
  is_train = True

  # Sets config options.
  keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

  _ensure_dir(flags_obj.model_dir)

  def _build_compiled_model():
    """Creates the model and compiles it with a freshly created optimizer."""
    built = transformer.create_model(params, is_train)
    built.compile(self._create_optimizer())
    return built

  strategy = self.distribution_strategy
  if strategy:
    with strategy.scope():
      model = _build_compiled_model()
  else:
    model = _build_compiled_model()
  model.summary()

  dataset = data_pipeline.train_input_fn(params).map(
      data_pipeline.map_data_for_transformer_fn,
      num_parallel_calls=params["num_parallel_calls"])

  callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

  # A round can never be longer than the whole training run.
  if flags_obj.train_steps < flags_obj.steps_between_evals:
    flags_obj.steps_between_evals = flags_obj.train_steps
  iterations = flags_obj.train_steps // flags_obj.steps_between_evals

  cased_score = None
  uncased_score = None
  cased_score_history = []
  uncased_score_history = []
  for round_index in range(1, iterations + 1):
    print("Start train iteration:{}/{}".format(round_index, iterations))
    # If TimeHistory is enabled, progress bar would be messy. Increase the
    # verbose level to get rid of it.
    verbosity = 2 if flags_obj.enable_time_history else 1
    history = model.fit(
        dataset,
        initial_epoch=round_index - 1,
        epochs=round_index,
        steps_per_epoch=flags_obj.steps_between_evals,
        callbacks=callbacks,
        verbose=verbosity)
    print("End train iteration:{}/{} global step:{}".format(
        round_index, iterations,
        round_index * flags_obj.steps_between_evals))
    tf.compat.v1.logging.info("Train history: {}".format(history.history))
    stats = misc.build_stats(history, callbacks)

    if flags_obj.bleu_source and flags_obj.bleu_ref:
      uncased_score, cased_score = self.eval()
      cased_score_history.append([round_index, cased_score])
      uncased_score_history.append([round_index, uncased_score])

  stats = misc.build_stats(history, callbacks)
  if uncased_score and cased_score:
    stats["bleu_uncased"] = uncased_score
    stats["bleu_cased"] = cased_score
    stats["bleu_uncased_history"] = uncased_score_history
    stats["bleu_cased_history"] = cased_score_history
  return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The training and eval stats produced by `build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    if tf.config.list_physical_devices('GPU'):
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  # A loop can never span more than one epoch.
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
  total_train_steps = train_epochs * per_epoch_steps

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      total_train_steps, eval_steps)

  # TimeHistory only gets a log directory when TensorBoard is enabled.
  if flags_obj.enable_tensorboard:
    time_history_logdir = flags_obj.model_dir
  else:
    time_history_logdir = None
  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size, flags_obj.log_steps, logdir=time_history_logdir)

  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              per_epoch_steps)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps

  # Checkpoints and summaries, when enabled, are written once per epoch.
  checkpoint_interval = None
  if flags_obj.enable_checkpoint_and_export:
    checkpoint_interval = per_epoch_steps
  summary_interval = None
  if flags_obj.enable_tensorboard:
    summary_interval = per_epoch_steps

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  resnet_controller = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      train_steps=per_epoch_steps * train_epochs,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      eval_steps=eval_steps,
      eval_interval=eval_interval)

  time_callback.on_train_begin()
  resnet_controller.train(evaluate=not flags_obj.skip_eval)
  time_callback.on_train_end()

  return build_stats(runnable, time_callback)
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): `dtype` is presumably a tf.DType (it is passed as `dtype=`
  # to the input pipeline below), in which case comparing it with the string
  # 'fp16' may never match and this guard would be dead. Confirm and compare
  # against the flag value or tf.float16 if so.
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in
      # normalization layer, which triggers tf.where and leads to extra memory
      # copy of input sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(
        classes=cifar_preprocessing.NUM_CLASSES)

    compile_kwargs = {
        'loss': 'categorical_crossentropy',
        'optimizer': optimizer,
        'metrics': (['categorical_accuracy']
                    if flags_obj.report_accuracy_metrics else None),
        'run_eagerly': flags_obj.run_eagerly,
    }
    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      compile_kwargs['experimental_run_tf_function'] = (
          flags_obj.force_v2_in_keras_compile)
    model.compile(**compile_kwargs)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])

  train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit train_steps flag caps the run to a single, shortened epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # TODO(b/135607227): Add device scope automatically in Keras training loop
  # when not using distribution strategy.
  no_dist_strat_device = None
  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  # The device scope is entered by hand (it is conditional), so it must be
  # released in a `finally` to avoid leaking it when training or evaluation
  # raises.
  try:
    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=train_steps,
        callbacks=callbacks,
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2)
  finally:
    if no_dist_strat_device is not None:
      # The context-manager protocol passes (exc_type, exc_value, traceback)
      # to __exit__; supply them explicitly instead of calling it bare.
      no_dist_strat_device.__exit__(None, None, None)

  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
def train(self):
  """Trains the model.

  Builds the model and optimizer under the distribution strategy scope,
  restores the latest checkpoint found in `flags_obj.model_dir` (if any) and
  then trains in chunks of at most `flags_obj.steps_between_evals` steps until
  `flags_obj.train_steps` is reached. After each chunk, `self.eval()` is run
  when both the `bleu_source` and `bleu_ref` flags are set.

  Returns:
    The stats dictionary populated by `misc.update_stats`. In the custom
    training loop it also carries the last training loss under "loss". BLEU
    scores and their per-iteration history are added when both scores are
    truthy.

  Raises:
    NotImplementedError: If the custom training loop is requested without a
      TPU, or Keras `model.fit` is requested with a TPU.
  """
  params = self.params
  flags_obj = self.flags_obj
  # Sets config options.
  keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

  _ensure_dir(flags_obj.model_dir)
  with distribute_utils.get_strategy_scope(self.distribution_strategy):
    model = transformer.create_model(params, is_train=True)
    opt = self._create_optimizer()

    current_step = 0
    checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
    latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
    if latest_checkpoint:
      checkpoint.restore(latest_checkpoint)
      logging.info("Loaded checkpoint %s", latest_checkpoint)
      # Resume counting from the optimizer's restored iteration count.
      current_step = opt.iterations.numpy()

    if params["use_ctl"]:
      train_loss_metric = tf.keras.metrics.Mean(
          "training_loss", dtype=tf.float32)
      if params["enable_tensorboard"]:
        summary_writer = tf.summary.create_file_writer(
            os.path.join(flags_obj.model_dir, "summary"))
      else:
        summary_writer = tf.summary.create_noop_writer()
      train_metrics = [train_loss_metric]
      if params["enable_metrics_in_training"]:
        train_metrics = train_metrics + model.metrics
    else:
      model.compile(opt)

  model.summary()

  if self.use_tpu:
    # Different from experimental_distribute_dataset,
    # distribute_datasets_from_function requires
    # per-replica/local batch size.
    # NOTE(review): `params` is `self.params`, so this rewrites the stored
    # batch size in place, and true division leaves it as a float. Confirm
    # the data pipeline expects that.
    params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
    train_ds = (
        self.distribution_strategy.distribute_datasets_from_function(
            lambda ctx: data_pipeline.train_input_fn(params, ctx)))
  else:
    train_ds = data_pipeline.train_input_fn(params)
    map_data_fn = data_pipeline.map_data_for_transformer_fn
    train_ds = train_ds.map(
        map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if params["use_ctl"]:
    train_ds_iterator = iter(train_ds)

  callbacks = self._create_callbacks(flags_obj.model_dir, params)

  # Only TimeHistory callback is supported for CTL
  if params["use_ctl"]:
    callbacks = [cb for cb in callbacks
                 if isinstance(cb, keras_utils.TimeHistory)]

  @tf.function
  def train_steps(iterator, steps):
    """Training steps function for TPU runs.

    Nothing is returned. The loss of each step is recorded in
    `train_loss_metric`, which is reset before every step, so after the call
    it reflects the last step only.

    Args:
      iterator: The input iterator of the training dataset.
      steps: An integer, the number of training steps.
    """

    def _step_fn(inputs):
      """Per-replica step function."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        logits = model([inputs, targets], training=True)
        loss = metrics.transformer_loss(logits, targets,
                                        params["label_smoothing"],
                                        params["vocab_size"])
        # Scales the loss, which results in using the average loss across all
        # of the replicas for backprop.
        scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

      # De-dupes variables due to keras tracking issues.
      tvars = list({id(v): v for v in model.trainable_variables}.values())
      grads = tape.gradient(scaled_loss, tvars)
      opt.apply_gradients(zip(grads, tvars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(loss)

    for _ in tf.range(steps):
      train_loss_metric.reset_states()
      self.distribution_strategy.run(
          _step_fn, args=(next(iterator),))

  cased_score, uncased_score = None, None
  cased_score_history, uncased_score_history = [], []
  # Bound before the loop so the stats code below is well defined even when
  # the restored step count already meets `train_steps` and the loop body
  # never runs.
  history = None
  train_loss = None
  while current_step < flags_obj.train_steps:
    remaining_steps = flags_obj.train_steps - current_step
    train_steps_per_eval = min(remaining_steps,
                               flags_obj.steps_between_evals)
    current_iteration = current_step // flags_obj.steps_between_evals

    logging.info("Start train iteration at global step:%s", current_step)
    if params["use_ctl"]:
      if not self.use_tpu:
        raise NotImplementedError(
            "Custom training loop on GPUs is not implemented.")

      # Runs training steps.
      with summary_writer.as_default():
        for cb in callbacks:
          cb.on_epoch_begin(current_iteration)
          cb.on_batch_begin(0)

        train_steps(
            train_ds_iterator,
            tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
        current_step += train_steps_per_eval
        train_loss = train_loss_metric.result().numpy().astype(float)
        logging.info("Train Step: %d/%d / loss = %s", current_step,
                     flags_obj.train_steps, train_loss)

        for cb in callbacks:
          cb.on_batch_end(train_steps_per_eval - 1)
          cb.on_epoch_end(current_iteration)

        if params["enable_tensorboard"]:
          for metric_obj in train_metrics:
            tf.summary.scalar(metric_obj.name, metric_obj.result(),
                              current_step)
            summary_writer.flush()

      for cb in callbacks:
        cb.on_train_end()

      if flags_obj.enable_checkpointing:
        # avoid check-pointing when running for benchmarking.
        checkpoint_name = checkpoint.save(
            os.path.join(flags_obj.model_dir,
                         "ctl_step_{}.ckpt".format(current_step)))
        logging.info("Saved checkpoint to %s", checkpoint_name)
    else:
      if self.use_tpu:
        raise NotImplementedError(
            "Keras model.fit on TPUs is not implemented.")
      history = model.fit(
          train_ds,
          initial_epoch=current_iteration,
          epochs=current_iteration + 1,
          steps_per_epoch=train_steps_per_eval,
          callbacks=callbacks,
          # If TimeHistory is enabled, progress bar would be messy. Increase
          # the verbose level to get rid of it.
          verbose=(2 if flags_obj.enable_time_history else 1))
      current_step += train_steps_per_eval
      logging.info("Train history: %s", history.history)

    logging.info("End train iteration at global step:%s", current_step)

    if flags_obj.bleu_source and flags_obj.bleu_ref:
      uncased_score, cased_score = self.eval()
      cased_score_history.append([current_iteration + 1, cased_score])
      uncased_score_history.append([current_iteration + 1, uncased_score])

  # The custom loop has no Keras history, so its last loss is reported here.
  stats = {}
  if history is None and train_loss is not None:
    stats["loss"] = train_loss
  misc.update_stats(history, stats, callbacks)
  if uncased_score and cased_score:
    stats["bleu_uncased"] = uncased_score
    stats["bleu_cased"] = cased_score
    stats["bleu_uncased_history"] = uncased_score_history
    stats["bleu_cased_history"] = cased_score_history
  return stats
def train(self):
  """Trains the model with Keras `model.fit`.

  Creates the training and validation datasets, builds and compiles the model
  under the distribution strategy scope, restores the newest checkpoint
  managed under `flags_obj.model_dir` (if any) and then runs a single
  `model.fit` call with validation.

  Returns:
    The history object returned by `model.fit`, or None when the restored
    step count has already reached `flags_obj.train_steps`.
  """
  params = self.params
  flags_obj = self.flags_obj
  # Sets config options.
  keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

  train_ds = self._create_dataset(params['data_dir'], repeat=None)
  # Without an explicit validation path, derive one from the training path.
  val_data_dir = params['val_data_dir'] or re.sub(
      r'-training.*', '-test.dtitle.tokenized.gz', params['data_dir'])
  val_ds = self._create_dataset(val_data_dir, repeat=1)
  validation_steps = (
      flags_obj.validation_example_count // params["batch_size"])
  val_ds = val_ds.take(validation_steps).cache()

  with distribution_utils.get_strategy_scope(self.distribution_strategy):
    model = self.create_model(mode='train')
    model.compile(optimizer=self._create_optimizer(params),
                  loss=self._create_loss_fn(params))

  # makedirs(exist_ok=True) also creates missing parents and does not fail
  # when another process creates the directory between a check and the mkdir.
  os.makedirs(flags_obj.model_dir, exist_ok=True)

  current_step = 0
  checkpoint = tf.train.Checkpoint(model=model)
  ckpt_mgr = tf.train.CheckpointManager(checkpoint,
                                        flags_obj.model_dir,
                                        max_to_keep=3,
                                        keep_checkpoint_every_n_hours=24)
  if ckpt_mgr.latest_checkpoint:
    # One throwaway fit on all-ones tensors before restoring.
    # NOTE(review): presumably this creates the model/optimizer variables so
    # that assert_consumed() can match every checkpointed value. Confirm.
    dummy_batch_shape = [params["batch_size"], params['max_target_length']]
    model.fit([
        tf.ones([params["batch_size"], params['max_input_length']],
                tf.int32),
        tf.ones(dummy_batch_shape, tf.int32)
    ], tf.ones(dummy_batch_shape, tf.int32), verbose=0)
    checkpoint.restore(ckpt_mgr.latest_checkpoint).assert_consumed()
    current_step = model.optimizer.iterations.numpy() - 1
    logging.info("Loaded checkpoint %s, current_step %d",
                 ckpt_mgr.latest_checkpoint, current_step)

  if current_step >= flags_obj.train_steps:
    logging.info("Reach the target train_steps({}) and exit.".format(
        flags_obj.train_steps))
    return None

  logging.info(f'Start train iteration at global step: {current_step}')
  model.summary()
  history = model.fit(
      train_ds,
      # Epochs are counted in units of `steps_between_evals` steps.
      initial_epoch=current_step // flags_obj.steps_between_evals,
      epochs=(flags_obj.train_steps - 1) // flags_obj.steps_between_evals + 1,
      steps_per_epoch=min(flags_obj.steps_between_evals,
                          flags_obj.train_steps - current_step),
      callbacks=self._create_callbacks(flags_obj.model_dir, current_step,
                                       flags_obj.steps_between_evals, params,
                                       ckpt_mgr),
      validation_data=val_ds,
      # Redundant with the take() above, but suppresses one warning.
      validation_steps=validation_steps,
      verbose=1)
  logging.info("Train history: {}".format(history.history))
  current_step = model.optimizer.iterations.numpy() - 1
  logging.info(
      "End train iteration at global step:{}".format(current_step))
  return history
def run_predict(flags_obj, datasets_override=None, strategy_override=None):
  """Runs the feature-extraction model over the eval dataset and saves outputs.

  Configures the session, mixed precision, image data format and a
  single-GPU distribution strategy from `flags_obj`, builds the
  'resnet50_features' model, loads the newest checkpoint and writes the three
  prediction outputs as .npy files under `GB_OPTIONS.out_npys_folder`.

  Args:
    flags_obj: An object containing parsed flag values.
    datasets_override: Unused by this function.
    strategy_override: Unused by this function.

  Returns:
    The string 'good'.

  Raises:
    FileNotFoundError: If no checkpoint is found under the load path.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=1,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=1,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  distribution_utils.undo_set_up_synthetic_data()

  # Only the eval split is used for prediction.
  _, eval_input_dataset, _, _ = setup_datasets(
      flags_obj, shuffle=False, save_labels=True)
  pred_input_dataset = eval_input_dataset

  with strategy_scope:
    model = build_model(imagenet_preprocessing.NUM_CLASSES,
                        mode='resnet50_features', save_labels=True)

    # Prefer the explicit pretrained path; fall back to the checkpoint folder.
    load_path = GB_OPTIONS.pretrained_filepath
    if load_path is None:
      load_path = GB_OPTIONS.checkpoint_folder
    latest = tf.train.latest_checkpoint(load_path)
    print(latest)
    if latest is None:
      # latest_checkpoint returns None when nothing is found. Fail with a
      # clear message instead of handing None to load_weights.
      raise FileNotFoundError(
          'No checkpoint found under %r.' % (load_path,))
    model.load_weights(latest)

  num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                    GB_OPTIONS.batch_size)

  pred = model.predict(
      pred_input_dataset,
      batch_size=GB_OPTIONS.batch_size,
      steps=num_eval_steps
  )

  # The folder is used as a plain string prefix, so it has to end with a path
  # separator for the files to land inside it.
  np.save(GB_OPTIONS.out_npys_folder+'out_X', pred[0])
  np.save(GB_OPTIONS.out_npys_folder+'out_labels', pred[1])
  np.save(GB_OPTIONS.out_npys_folder+'out_ori_labels', pred[2])

  return 'good'
def run_train(flags_obj):
  """Trains the 'resnet50' model, exports it and returns the run stats.

  Configures the session, mixed precision, image data format and distribution
  strategy from `flags_obj`, trains for `GB_OPTIONS.num_epochs` epochs with a
  weights-only checkpoint callback, saves the model (without optimizer) under
  `<checkpoint_folder>/saved_model`, evaluates unless `flags_obj.skip_eval` is
  set, and finally copies `config.py` into the checkpoint folder.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The stats object built by `common.build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  distribution_utils.undo_set_up_synthetic_data()

  train_input_dataset, eval_input_dataset, tr_dataset, _ = setup_datasets(
      flags_obj)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=GB_OPTIONS.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // GB_OPTIONS.batch_size)

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = build_model(imagenet_preprocessing.NUM_CLASSES, mode='resnet50')

    if GB_OPTIONS.pretrained_filepath is not None:
      latest = tf.train.latest_checkpoint(GB_OPTIONS.pretrained_filepath)
      print(latest)
      model.load_weights(latest)

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=['sparse_categorical_accuracy'])

  train_epochs = GB_OPTIONS.num_epochs

  # The poison/cover counts only go into the checkpoint file name. Datasets
  # without an `n_poison` attribute are recorded as 0/0.
  if hasattr(tr_dataset, "n_poison"):
    n_poison = tr_dataset.n_poison
    n_cover = tr_dataset.n_cover
  else:
    n_poison = 0
    n_cover = 0

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=False,
      model_dir=GB_OPTIONS.checkpoint_folder
  )

  # `{epoch:04d}` is left for Keras to fill in; the counts are baked in here.
  ckpt_full_path = os.path.join(
      GB_OPTIONS.checkpoint_folder,
      'model.ckpt-{epoch:04d}-p%d-c%d'%(n_poison,n_cover))
  callbacks.append(tf.keras.callbacks.ModelCheckpoint(
      ckpt_full_path, save_weights_only=True, save_best_only=True))

  num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                    GB_OPTIONS.batch_size)

  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    eval_input_dataset = None

  history = model.fit(
      train_input_dataset,
      epochs=train_epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=callbacks,
      validation_steps=num_eval_steps,
      validation_data=eval_input_dataset,
      validation_freq=flags_obj.epochs_between_evals
  )

  export_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'saved_model')
  model.save(export_path, include_optimizer=False)

  # With skip_eval the eval dataset and step count were cleared above, so
  # there is nothing to evaluate and build_stats receives None instead.
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(
        eval_input_dataset,
        steps=num_eval_steps,
        verbose=2
    )

  stats = common.build_stats(history, eval_output, callbacks)

  # Keep a copy of the run configuration next to the checkpoints.
  cmmd = 'cp config.py '+GB_OPTIONS.checkpoint_folder
  os.system(cmmd)

  return stats
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # Local import: only needed for the optional device scope.

  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): this compares against the string 'fp16'. Confirm that
  # get_tf_dtype can actually return a value equal to it.
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj))
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_main.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype)

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_main.parse_record)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=(['sparse_categorical_accuracy']
                           if flags_obj.report_accuracy_metrics else None),
                  run_eagerly=flags_obj.run_eagerly)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit step budget caps the run to a single, possibly shorter, epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # ExitStack keeps the device scope optional while still leaving it through
  # the regular context-manager protocol, including when fit/evaluate raise.
  with contextlib.ExitStack() as device_stack:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training
      # loop when not using distribution strategy.
      device_stack.enter_context(tf.device('/device:GPU:0'))

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)

  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
def run_ncf(_):
  """Run NCF training and eval with Keras.

  Training input is either generated online by a producer thread (when
  `params["train_dataset_path"]` is None) or read from pre-generated files
  described by the JSON file at `params["input_meta_data_path"]`. Depending
  on `params["keras_use_ctl"]`, the model is trained with a custom training
  loop or with Keras compile/fit.

  Args:
    _: Unused positional argument.

  Returns:
    The stats object built by `build_stats`, or None when the requested
    configuration is rejected (an error is logged in that case).
  """
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  if FLAGS.seed is not None:
    print("Setting tf seed")
    tf.random.set_seed(FLAGS.seed)

  # TODO(seemuch): Support different train and eval batch sizes
  if FLAGS.eval_batch_size != FLAGS.batch_size:
    logging.warning(
        "The Keras implementation of NCF currently does not support batch_size "
        "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
        "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size))
    FLAGS.eval_batch_size = FLAGS.batch_size

  params = ncf_common.parse_flags(FLAGS)
  model_helpers.apply_clean(flags.FLAGS)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus)
  params["distribute_strategy"] = strategy

  if not keras_utils.is_v2_0() and strategy is not None:
    logging.error(
        "NCF Keras only works with distribution strategy in TF 2.0")
    return

  if (params["keras_use_ctl"] and
      (not keras_utils.is_v2_0() or strategy is None)):
    logging.error(
        "Custom training loop only works with tensorflow 2.0 and dist strat.")
    return

  # ncf_common rounds eval_batch_size (this is needed due to a reshape during
  # eval). This carries over that rounding to batch_size as well. This is the
  # per device batch size
  params["batch_size"] = params["eval_batch_size"]
  batch_size = params["batch_size"]

  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
  callbacks = [time_callback]

  producer, input_meta_data = None, None
  generate_input_online = params["train_dataset_path"] is None

  if generate_input_online:
    # Start data producing thread.
    num_users, num_items, num_train_steps, num_eval_steps, producer = (
        ncf_common.get_inputs(params))
    producer.start()
    per_epoch_callback = IncrementEpochCallback(producer)
    callbacks.append(per_epoch_callback)
  else:
    assert params["eval_dataset_path"] and params["input_meta_data_path"]
    # tf.io.gfile is the file API exposed by the TF2 namespace that the rest
    # of this function already relies on (e.g. tf.random.set_seed).
    with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
      input_meta_data = json.loads(reader.read().decode("utf-8"))
      num_users = input_meta_data["num_users"]
      num_items = input_meta_data["num_items"]

  params["num_users"], params["num_items"] = num_users, num_items
  (train_input_dataset, eval_input_dataset,
   num_train_steps, num_eval_steps) = (
       ncf_input_pipeline.create_ncf_input_data(
           params, producer, input_meta_data))
  steps_per_epoch = None if generate_input_online else num_train_steps

  if FLAGS.early_stopping:
    early_stopping_callback = CustomEarlyStopping(
        "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
    callbacks.append(early_stopping_callback)

  with distribution_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])

  if params["keras_use_ctl"]:
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction="sum",
        from_logits=True)
    train_input_iterator = strategy.make_dataset_iterator(
        train_input_dataset)
    eval_input_iterator = strategy.make_dataset_iterator(
        eval_input_dataset)

    @tf.function
    def train_step():
      """Called once per step to train the model."""

      def step_fn(features):
        """Computes loss and applied gradient per replica."""
        with tf.GradientTape() as tape:
          softmax_logits = keras_model(features)
          labels = features[rconst.TRAIN_LABEL_KEY]
          loss = loss_object(
              labels,
              softmax_logits,
              sample_weight=features[rconst.VALID_POINT_MASK])
          # The loss object sums over examples, so divide by the global batch
          # size (per-device batch size times replica count).
          loss *= (1.0 / (batch_size * strategy.num_replicas_in_sync))

        grads = tape.gradient(loss, keras_model.trainable_variables)
        # Converting gradients to dense form helps in perf on GPU for NCF
        grads = neumf_model.sparse_to_dense_grads(
            list(zip(grads, keras_model.trainable_variables)))
        optimizer.apply_gradients(grads)
        return loss

      per_replica_losses = strategy.experimental_run(
          step_fn, train_input_iterator)
      mean_loss = strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
      return mean_loss

    @tf.function
    def eval_step():
      """Called once per eval step to compute eval metrics."""

      def step_fn(features):
        """Computes eval metrics per replica."""
        softmax_logits = keras_model(features)
        in_top_k, metric_weights = metric_fn(
            softmax_logits, features[rconst.DUPLICATE_MASK], params)
        hr_sum = tf.reduce_sum(in_top_k * metric_weights)
        hr_count = tf.reduce_sum(metric_weights)
        return hr_sum, hr_count

      per_replica_hr_sum, per_replica_hr_count = (
          strategy.experimental_run(step_fn, eval_input_iterator))
      hr_sum = strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
      hr_count = strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
      return hr_sum, hr_count

    time_callback.on_train_begin()
    for epoch in range(FLAGS.train_epochs):
      for cb in callbacks:
        cb.on_epoch_begin(epoch)

      # As NCF dataset is sampled with randomness, not repeating
      # data elements in each epoch has significant impact on
      # convergence. As so, offline-generated TF record files
      # contains all epoch worth of data. Thus we do not need
      # to initialize dataset when reading from tf record files.
      if generate_input_online:
        train_input_iterator.initialize()

      train_loss = 0
      for step in range(num_train_steps):
        time_callback.on_batch_begin(step + epoch * num_train_steps)
        train_loss += train_step()
        time_callback.on_batch_end(step + epoch * num_train_steps)
      train_loss /= num_train_steps
      logging.info("Done training epoch %s, epoch loss=%s.",
                   epoch + 1, train_loss)

      eval_input_iterator.initialize()
      hr_sum = 0
      hr_count = 0
      for _ in range(num_eval_steps):
        step_hr_sum, step_hr_count = eval_step()
        hr_sum += step_hr_sum
        hr_count += step_hr_count
      logging.info("Done eval epoch %s, hr=%s.", epoch + 1,
                   hr_sum / hr_count)

      if (FLAGS.early_stopping and
          float(hr_sum / hr_count) > params["hr_threshold"]):
        break

    time_callback.on_train_end()
    # Same two-slot layout as the compile/fit branch; only the hit rate of the
    # last evaluated epoch is filled in.
    eval_results = [None, hr_sum / hr_count]

  else:
    with distribution_utils.get_strategy_scope(strategy):
      keras_model.compile(
          optimizer=optimizer,
          run_eagerly=FLAGS.run_eagerly,
          run_distributed=FLAGS.force_v2_in_keras_compile)

      history = keras_model.fit(
          train_input_dataset,
          epochs=FLAGS.train_epochs,
          steps_per_epoch=steps_per_epoch,
          callbacks=callbacks,
          validation_data=eval_input_dataset,
          validation_steps=num_eval_steps,
          verbose=2)

      logging.info("Training done. Start evaluating")

      eval_results = keras_model.evaluate(
          eval_input_dataset,
          steps=num_eval_steps,
          verbose=2)

      logging.info("Keras evaluation is done.")

    if history and history.history:
      train_history = history.history
      train_loss = train_history["loss"][-1]

  stats = build_stats(train_loss, eval_results, time_callback)
  return stats
def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None):
  """Run BERT training.

  In 'export_only' mode the classifier is exported and nothing is trained.
  In 'train_and_eval' mode the classifier is trained through
  `run_bert_classifier` and exported afterwards when
  `FLAGS.model_export_path` is set.

  Args:
    strategy: Distribution strategy to train with. Must be truthy in
      'train_and_eval' mode.
    input_meta_data: Mapping that provides 'train_data_size' and
      'eval_data_size'; it is also forwarded to the classifier helpers.
    model_config: Model configuration forwarded to the classifier helpers.
    train_input_fn: Forwarded to `run_bert_classifier`.
    eval_input_fn: Forwarded to `run_bert_classifier`.
    init_checkpoint: Checkpoint to start from; falls back to
      `FLAGS.init_checkpoint` when falsy.
    custom_callbacks: Optional list of callbacks. A `TimeHistory` callback is
      appended to it when `FLAGS.log_steps` is set.

  Returns:
    The value returned by `run_bert_classifier`, or None in 'export_only'
    mode.

  Raises:
    ValueError: If `FLAGS.mode` is unsupported or `strategy` is not specified.
  """
  if FLAGS.mode == 'export_only':
    # As Keras ModelCheckpoint callback used with Keras compile/fit() API
    # internally uses model.save_weights() to save checkpoints, we must
    # use model.load_weights() when Keras compile/fit() is used.
    export_classifier(FLAGS.model_export_path, input_meta_data,
                      FLAGS.use_keras_compile_fit,
                      model_config, FLAGS.model_dir)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)

  # Enables XLA in Session Config. Should not be set for TPU.
  # Passed by keyword: other call sites in this file show that `enable_eager`
  # can be the first parameter of set_session_config, in which case a
  # positional argument would not reach `enable_xla`.
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs
  train_data_size = input_meta_data['train_data_size']
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  # Warm up over the first 10% of all training steps.
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  if not custom_callbacks:
    custom_callbacks = []

  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      run_eagerly=FLAGS.run_eagerly,
      use_keras_compile_fit=FLAGS.use_keras_compile_fit,
      custom_callbacks=custom_callbacks)

  if FLAGS.model_export_path:
    # As Keras ModelCheckpoint callback used with Keras compile/fit() API
    # internally uses model.save_weights() to save checkpoints, we must
    # use model.load_weights() when Keras compile/fit() is used.
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path,
        model=trained_model,
        restore_model_using_load_weights=FLAGS.use_keras_compile_fit)
  return trained_model
def run_ncf(_):
  """Train and evaluate NCF with Keras and return the collected stats.

  Input data always comes from the producer returned by
  `ncf_common.get_inputs`. When `params["keras_use_ctl"]` is set, the model
  runs through a hand-written train/eval loop; otherwise it goes through
  Keras compile/fit/evaluate.

  Args:
    _: Unused positional argument.

  Returns:
    The stats object built by `build_stats`, or None when a custom training
    loop is requested in an unsupported setup (an error is logged).
  """
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  seed = FLAGS.seed
  if seed is not None:
    print("Setting tf seed")
    tf.random.set_seed(seed)

  # TODO(seemuch): Support different train and eval batch sizes
  if FLAGS.eval_batch_size != FLAGS.batch_size:
    mismatch_warning = (
        "The Keras implementation of NCF currently does not support batch_size "
        "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
        "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size))
    logging.warning(mismatch_warning)
    FLAGS.eval_batch_size = FLAGS.batch_size

  params = ncf_common.parse_flags(FLAGS)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus)
  params["distribute_strategy"] = strategy

  # The custom loop needs both TF 2.0 and a distribution strategy.
  if params["keras_use_ctl"] and not (
      keras_utils.is_v2_0() and strategy is not None):
    logging.error(
        "Custom training loop only works with tensorflow 2.0 and dist strat.")
    return

  # eval_batch_size has been rounded by ncf_common (a reshape during eval
  # requires it); reuse that rounded value as the per-device batch size.
  params["batch_size"] = params["eval_batch_size"]
  batch_size = params["batch_size"]

  inputs = ncf_common.get_inputs(params)
  num_users, num_items, num_train_steps, num_eval_steps, producer = inputs
  params["num_users"] = num_users
  params["num_items"] = num_items
  producer.start()
  model_helpers.apply_clean(flags.FLAGS)

  batches_per_step = params["batches_per_step"]
  train_data, eval_data = _get_train_and_eval_data(producer, params)
  # Distributed training requires the dataset to call batch(). The batch size
  # used here is the number of replicas involved, so that every replica gets
  # an even slice of data. drop_remainder=True makes batch() return a fixed
  # shape instead of None, which avoids an expensive broadcast in
  # weighted_loss.
  train_data = train_data.batch(batches_per_step, drop_remainder=True)
  eval_data = eval_data.batch(batches_per_step, drop_remainder=True)

  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
  per_epoch_callback = IncrementEpochCallback(producer)
  callbacks = [per_epoch_callback, time_callback]

  if FLAGS.early_stopping:
    callbacks.append(
        CustomEarlyStopping("val_HR_METRIC",
                            desired_value=FLAGS.hr_threshold))

  with distribution_utils.get_strategy_scope(strategy):
    ncf_model = _get_keras_model(params)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])

  if not params["keras_use_ctl"]:
    # Keras compile/fit path.
    with distribution_utils.get_strategy_scope(strategy):
      ncf_model.compile(optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)

      history = ncf_model.fit(train_data,
                              epochs=FLAGS.train_epochs,
                              callbacks=callbacks,
                              validation_data=eval_data,
                              validation_steps=num_eval_steps,
                              verbose=2)

      logging.info("Training done. Start evaluating")

      eval_results = ncf_model.evaluate(eval_data,
                                        steps=num_eval_steps,
                                        verbose=2)

      logging.info("Keras evaluation is done.")

    if history and history.history:
      # Report the loss of the final epoch.
      train_loss = history.history["loss"][-1]

    return build_stats(train_loss, eval_results, time_callback)

  # Custom training loop path.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.SUM,
      from_logits=True)
  train_iterator = strategy.make_dataset_iterator(train_data)
  eval_iterator = strategy.make_dataset_iterator(eval_data)

  @tf.function
  def train_step():
    """Runs one training step on every replica and returns the summed loss."""

    def replica_train_fn(features):
      """Computes the loss and applies the gradients on one replica."""
      with tf.GradientTape() as tape:
        logits = ncf_model(features)
        loss = loss_object(
            features[rconst.TRAIN_LABEL_KEY],
            logits,
            sample_weight=features[rconst.VALID_POINT_MASK])
        # The loss object sums over examples; divide by the global batch size.
        global_batch_size = batch_size * strategy.num_replicas_in_sync
        loss *= (1.0 / global_batch_size)

      model_vars = ncf_model.trainable_variables
      grads = tape.gradient(loss, model_vars)
      # Dense gradients perform better than sparse ones on GPU for NCF.
      grads_and_vars = neumf_model.sparse_to_dense_grads(
          list(zip(grads, ncf_model.trainable_variables)))
      optimizer.apply_gradients(grads_and_vars)
      return loss

    replica_losses = strategy.experimental_run(
        replica_train_fn, train_iterator)
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None)

  @tf.function
  def eval_step():
    """Runs one eval step and returns the summed hit-rate numerator/count."""

    def replica_eval_fn(features):
      """Computes the weighted hit sum and weight total on one replica."""
      logits = ncf_model(features)
      in_top_k, metric_weights = metric_fn(
          logits, features[rconst.DUPLICATE_MASK], params)
      weighted_hits = tf.reduce_sum(in_top_k * metric_weights)
      total_weight = tf.reduce_sum(metric_weights)
      return weighted_hits, total_weight

    replica_hits, replica_weights = strategy.experimental_run(
        replica_eval_fn, eval_iterator)
    summed_hits = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_hits, axis=None)
    summed_weights = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_weights, axis=None)
    return summed_hits, summed_weights

  time_callback.on_train_begin()
  for epoch in range(FLAGS.train_epochs):
    per_epoch_callback.on_epoch_begin(epoch)
    train_iterator.initialize()

    train_loss = 0
    first_step_of_epoch = epoch * num_train_steps
    for step in range(num_train_steps):
      global_step = step + first_step_of_epoch
      time_callback.on_batch_begin(global_step)
      train_loss += train_step()
      time_callback.on_batch_end(global_step)
    train_loss /= num_train_steps
    logging.info("Done training epoch %s, epoch loss=%s.",
                 epoch + 1, train_loss)

    eval_iterator.initialize()
    hr_sum = 0
    hr_count = 0
    for _ in range(num_eval_steps):
      step_hr_sum, step_hr_count = eval_step()
      hr_sum += step_hr_sum
      hr_count += step_hr_count
    hit_rate = hr_sum / hr_count
    logging.info("Done eval epoch %s, hr=%s.", epoch + 1, hit_rate)

    if FLAGS.early_stopping and float(hit_rate) > params["hr_threshold"]:
      break

  time_callback.on_train_end()
  # Same two-slot layout as the compile/fit branch; only the hit rate of the
  # last evaluated epoch is filled in.
  eval_results = [None, hit_rate]

  return build_stats(train_loss, eval_results, time_callback)
def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False,
                init_checkpoint=None,
                sub_model_export_name=None):
    """Trains a BERT SQuAD model through the customized training loop.

    Args:
        strategy: Distribution strategy forwarded to the training loop. When
            truthy, an informational message is logged first.
        input_meta_data: Mapping read for 'train_data_size' and
            'max_seq_length'.
        bert_config: Configuration forwarded to `bert_models.squad_model`.
        custom_callbacks: Forwarded unchanged to the training loop.
        run_eagerly: Forwarded unchanged to the training loop.
        init_checkpoint: Checkpoint to start from; falls back to
            `FLAGS.init_checkpoint` when falsy.
        sub_model_export_name: Forwarded unchanged to the training loop.
    """
    if strategy:
        logging.info(
            'Training using customized training loop with distribution'
            ' strategy.')

    # Enables XLA in Session Config. Should not be set for TPU.
    keras_utils.set_session_config(FLAGS.enable_xla)
    performance.set_mixed_precision_policy(common_flags.dtype())

    num_epochs = FLAGS.num_train_epochs
    train_size = input_meta_data['train_data_size']
    seq_length = input_meta_data['max_seq_length']
    batch_size = FLAGS.train_batch_size

    steps_per_epoch = int(train_size / batch_size)
    # Warm up over 10% of the total number of training steps.
    warmup_steps = int(num_epochs * train_size * 0.1 / batch_size)

    train_input_fn = get_dataset_fn(
        FLAGS.train_data_path,
        seq_length,
        FLAGS.train_batch_size,
        is_training=True)

    def _get_squad_model():
        """Builds the SQuAD model (with its optimizer attached) and core model."""
        model, encoder = bert_models.squad_model(
            bert_config,
            seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable)
        base_optimizer = optimization.create_optimizer(
            FLAGS.learning_rate,
            steps_per_epoch * num_epochs,
            warmup_steps,
            FLAGS.end_lr,
            FLAGS.optimizer_type)
        model.optimizer = performance.configure_optimizer(
            base_optimizer,
            use_float16=common_flags.use_float16(),
            use_graph_rewrite=common_flags.use_graph_rewrite())
        return model, encoder

    # The allreduce callbacks and allreduce_bytes_per_pack only take effect
    # when explicit_allreduce = True. In that mode optimizer.apply_gradients()
    # no longer implicitly allreduces gradients; gradients are allreduced
    # manually and the allreduced grads_and_vars are passed to
    # apply_gradients().
    # NOTE(review): the original comment says clip_by_global_norm moves to
    # after allreduce, yet the callback is passed as pre_allreduce_callbacks
    # below -- confirm which one is intended.
    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_squad_model,
        loss_fn=get_loss_fn(),
        model_dir=FLAGS.model_dir,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=FLAGS.steps_per_loop,
        epochs=num_epochs,
        train_input_fn=train_input_fn,
        init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
        sub_model_export_name=sub_model_export_name,
        run_eagerly=run_eagerly,
        custom_callbacks=custom_callbacks,
        explicit_allreduce=FLAGS.explicit_allreduce,
        pre_allreduce_callbacks=[
            model_training_utils.clip_by_global_norm_callback
        ],
        allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)
def run(callbacks=None):
    """Builds the experiment parameters and input functions, then runs them.

    Args:
        callbacks: Optional list of callbacks. When `FLAGS.log_steps` is set a
            `keras_utils.TimeHistory` is appended to it (a list passed by the
            caller is modified in place).

    Returns:
        The value returned by `run_executor`.

    Raises:
        ValueError: If neither a training nor an eval file pattern is
            provided through flags or params.
    """
    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    params = config_factory.config_generator(FLAGS.model)
    # Apply the config file first and FLAGS.params_override second, strictly.
    for override_source in (FLAGS.config_file, FLAGS.params_override):
        params = params_dict.override_params_dict(
            params, override_source, is_strict=True)

    flag_overrides = {
        'strategy_type': FLAGS.strategy_type,
        'model_dir': FLAGS.model_dir,
        'strategy_config': executor.strategy_flags_dict(),
    }
    params.override(flag_overrides, is_strict=False)

    # Make sure use_tpu and strategy_type are in sync.
    params.use_tpu = (params.strategy_type == 'tpu')

    if not params.use_tpu:
        # bfloat16 and synchronized batch norm are switched off off-TPU.
        non_tpu_overrides = {
            'architecture': {
                'use_bfloat16': False,
            },
            'norm_activation': {
                'use_sync_bn': False,
            },
        }
        params.override(non_tpu_overrides, is_strict=True)

    params.validate()
    params.lock()
    logging.info('Model Parameters: %s', pprint.pformat(params.as_dict()))

    # Flags take precedence over the patterns stored in params.
    training_file_pattern = (FLAGS.training_file_pattern
                             or params.train.train_file_pattern)
    eval_file_pattern = (FLAGS.eval_file_pattern
                         or params.eval.eval_file_pattern)
    if not (training_file_pattern or eval_file_pattern):
        raise ValueError(
            'Must provide at least one of training_file_pattern and '
            'eval_file_pattern.')

    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size) if training_file_pattern else None

    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples) if eval_file_pattern else None

    if callbacks is None:
        callbacks = []

    if FLAGS.log_steps:
        time_history = keras_utils.TimeHistory(
            batch_size=params.train.batch_size,
            log_steps=FLAGS.log_steps,
        )
        callbacks.append(time_history)

    return run_executor(
        params,
        FLAGS.mode,
        checkpoint_path=FLAGS.checkpoint_path,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        callbacks=callbacks)
def run_ncf(_):
    """Runs NCF training and evaluation with Keras.

    Args:
        _: Unused positional argument.

    Returns:
        The stats built by `build_stats`, or None when the requested
        combination of TF version / strategy / training loop is rejected
        (an error is logged in that case).
    """
    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    if FLAGS.seed is not None:
        print("Setting tf seed")
        tf.random.set_seed(FLAGS.seed)

    params = ncf_common.parse_flags(FLAGS)
    model_helpers.apply_clean(flags.FLAGS)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    params["distribute_strategy"] = strategy

    # Reject unsupported configurations up front.
    is_tf2 = keras_utils.is_v2_0()
    if strategy is not None and not is_tf2:
        logging.error(
            "NCF Keras only works with distribution strategy in TF 2.0")
        return
    if params["keras_use_ctl"] and (strategy is None or not is_tf2):
        logging.error(
            "Custom training loop only works with tensorflow 2.0 and dist strat."
        )
        return
    if params["use_tpu"] and not params["keras_use_ctl"]:
        logging.error(
            "Custom training loop must be used when using TPUStrategy.")
        return

    time_callback = keras_utils.TimeHistory(params["batch_size"],
                                            FLAGS.log_steps)
    callbacks = [time_callback]

    producer = None
    input_meta_data = None
    generate_input_online = params["train_dataset_path"] is None

    if generate_input_online:
        # Start data producing thread.
        num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
        producer.start()
        callbacks.append(IncrementEpochCallback(producer))
    else:
        # Pre-generated data: user/item counts come from the metadata file.
        assert params["eval_dataset_path"] and params["input_meta_data_path"]
        with tf.io.gfile.GFile(params["input_meta_data_path"],
                               "rb") as reader:
            input_meta_data = json.loads(reader.read().decode("utf-8"))
            num_users = input_meta_data["num_users"]
            num_items = input_meta_data["num_items"]

    params["num_users"] = num_users
    params["num_items"] = num_items

    if FLAGS.early_stopping:
        callbacks.append(
            CustomEarlyStopping("val_HR_METRIC",
                                desired_value=FLAGS.hr_threshold))

    use_remote_tpu = params["use_tpu"] and FLAGS.tpu
    primary_cpu_task = tpu_lib.get_primary_cpu_task(use_remote_tpu)

    with tf.device(primary_cpu_task):
        input_data = ncf_input_pipeline.create_ncf_input_data(
            params, producer, input_meta_data, strategy)
        (train_input_dataset, eval_input_dataset, num_train_steps,
         num_eval_steps) = input_data
        if generate_input_online:
            steps_per_epoch = None
        else:
            steps_per_epoch = num_train_steps

    with distribution_utils.get_strategy_scope(strategy):
        keras_model = _get_keras_model(params)
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=params["learning_rate"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"])
        if FLAGS.dtype == "fp16":
            rewrite = (tf.compat.v1.train.experimental
                       .enable_mixed_precision_graph_rewrite)
            optimizer = rewrite(
                optimizer,
                loss_scale=flags_core.get_loss_scale(
                    FLAGS, default_for_fp16="dynamic"))

        if params["keras_use_ctl"]:
            train_loss, eval_results = run_ncf_custom_training(
                params,
                strategy,
                keras_model,
                optimizer,
                callbacks,
                train_input_dataset,
                eval_input_dataset,
                num_train_steps,
                num_eval_steps,
                generate_input_online=generate_input_online)
        else:
            # TODO(b/138957587): Remove when force_v2_in_keras_compile is on
            # longer a valid arg for this model. Also remove as a valid flag.
            compile_kwargs = {
                "optimizer": optimizer,
                "run_eagerly": FLAGS.run_eagerly,
            }
            if FLAGS.force_v2_in_keras_compile is not None:
                compile_kwargs["experimental_run_tf_function"] = (
                    FLAGS.force_v2_in_keras_compile)
            keras_model.compile(**compile_kwargs)

            history = keras_model.fit(
                train_input_dataset,
                epochs=FLAGS.train_epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=callbacks,
                validation_data=eval_input_dataset,
                validation_steps=num_eval_steps,
                verbose=2)

            logging.info("Training done. Start evaluating")

            eval_loss_and_metrics = keras_model.evaluate(
                eval_input_dataset, steps=num_eval_steps, verbose=2)

            logging.info("Keras evaluation is done.")

            # Keras evaluate() API returns scalar loss and metric values from
            # evaluation as a list. Here, the returned list would contain
            # [evaluation loss, hr sum, hr count].
            eval_loss = eval_loss_and_metrics[0]
            eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]

            # Format evaluation result into [eval loss, eval hit accuracy].
            eval_results = [eval_loss, eval_hit_rate]

            if history and history.history:
                # Report the loss recorded for the last epoch.
                train_loss = history.history["loss"][-1]

    return build_stats(train_loss, eval_results, time_callback)
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If `--fp16_implementation=graph_rewrite` is requested
            without `--use_tf_function`.

    Returns:
        Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    train_ds, test_ds = get_input_dataset(flags_obj, strategy)
    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    # Never run more steps in one loop than an epoch contains.
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    logging.info(
        "Training %d epochs, each epoch has %d steps, "
        "total steps: %d; Eval %d steps", train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                            flags_obj.log_steps)

    with distribution_utils.get_strategy_scope(strategy):
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in common.LR_SCHEDULE],
            compute_lr_on_cpu=True)
        optimizer = common.get_optimizer(lr_schedule)

        if flags_obj.fp16_implementation == 'graph_rewrite':
            if not flags_obj.use_tf_function:
                raise ValueError(
                    '--fp16_implementation=graph_rewrite requires '
                    '--use_tf_function to be true')
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer, loss_scale))

        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        trainable_variables = model.trainable_variables

        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = model(images, training=True)

                prediction_loss = (
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits))
                loss = tf.reduce_sum(prediction_loss) * (
                    1.0 / flags_obj.batch_size)
                num_replicas = (
                    tf.distribute.get_strategy().num_replicas_in_sync)

                if flags_obj.single_l2_loss_op:
                    # Single fused L2 op over all non-batch-norm variables.
                    filtered_variables = [
                        tf.reshape(v, (-1,))
                        for v in trainable_variables
                        if 'bn' not in v.name
                    ]
                    l2_loss = (resnet_model.L2_WEIGHT_DECAY * 2 *
                               tf.nn.l2_loss(
                                   tf.concat(filtered_variables, axis=0)))
                    loss += (l2_loss / num_replicas)
                else:
                    loss += (tf.reduce_sum(model.losses) / num_replicas)

                # Scale the loss for backprop only. The unscaled `loss` is
                # kept so the reported metric is not multiplied by the loss
                # scale.
                backprop_loss = loss
                if flags_obj.dtype == "fp16":
                    backprop_loss = optimizer.get_scaled_loss(loss)

            grads = tape.gradient(backprop_loss, trainable_variables)

            # Unscale the grads
            if flags_obj.dtype == "fp16":
                grads = optimizer.get_unscaled_gradients(grads)

            optimizer.apply_gradients(zip(grads, trainable_variables))
            train_loss.update_state(loss)
            training_accuracy.update_state(labels, logits)

        @tf.function
        def train_steps(iterator, steps):
            """Performs distributed training steps in a loop."""
            for _ in tf.range(steps):
                strategy.experimental_run_v2(step_fn, args=(next(iterator),))

        def train_single_step(iterator):
            """Runs one training step, through the strategy when there is one."""
            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator),))
            else:
                return step_fn(next(iterator))

        def test_step(iterator):
            """Evaluation StepFn."""

            def step_fn(inputs):
                images, labels = inputs
                logits = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, logits)

            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator),))
            else:
                step_fn(next(iterator))

        if flags_obj.use_tf_function:
            train_single_step = tf.function(train_single_step)
            test_step = tf.function(test_step)

        train_iter = iter(train_ds)
        time_callback.on_train_begin()
        for epoch in range(train_epochs):
            train_loss.reset_states()
            training_accuracy.reset_states()

            steps_in_current_epoch = 0
            while steps_in_current_epoch < per_epoch_steps:
                time_callback.on_batch_begin(steps_in_current_epoch +
                                             epoch * per_epoch_steps)
                steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                                      steps_per_loop)
                if steps == 1:
                    train_single_step(train_iter)
                else:
                    # Converts steps to a Tensor to avoid tf.function
                    # retracing.
                    train_steps(train_iter,
                                tf.convert_to_tensor(steps, dtype=tf.int32))
                time_callback.on_batch_end(steps_in_current_epoch +
                                           epoch * per_epoch_steps)
                steps_in_current_epoch += steps

            logging.info('Training loss: %s, accuracy: %s at epoch %d',
                         train_loss.result().numpy(),
                         training_accuracy.result().numpy(), epoch + 1)

            if (not flags_obj.skip_eval and
                    (epoch + 1) % flags_obj.epochs_between_evals == 0):
                test_loss.reset_states()
                test_accuracy.reset_states()

                test_iter = iter(test_ds)
                for _ in range(eval_steps):
                    test_step(test_iter)

                logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                             test_loss.result().numpy(),
                             test_accuracy.result().numpy(), epoch + 1)

        time_callback.on_train_end()

        eval_result = None
        train_result = None
        if not flags_obj.skip_eval:
            eval_result = [
                test_loss.result().numpy(),
                test_accuracy.result().numpy()
            ]
            train_result = [
                train_loss.result().numpy(),
                training_accuracy.result().numpy()
            ]

        stats = build_stats(train_result, eval_result, time_callback)
        return stats
def train(self):
    """Trains the model.

    Restores the latest checkpoint found in `flags_obj.model_dir` (if any),
    then trains in chunks of at most `flags_obj.steps_between_evals` steps
    until `flags_obj.train_steps` is reached, optionally evaluating BLEU
    after each chunk.

    Returns:
        A stats dict. With the custom training loop it holds the last
        training loss under "loss" (None when no training step ran because
        the restored step already reached `train_steps`); otherwise it is
        built by `misc.build_stats`. BLEU scores and their histories are
        added when both scores are truthy.

    Raises:
        NotImplementedError: For the custom training loop off TPU, or for
            Keras `model.fit` on TPU.
    """
    params = self.params
    flags_obj = self.flags_obj
    # Sets config options.
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

    _ensure_dir(flags_obj.model_dir)
    with distribution_utils.get_strategy_scope(self.distribution_strategy):
        model = transformer.create_model(params, is_train=True)
        opt = self._create_optimizer()

        # Restore from the latest checkpoint, if one exists.
        current_step = 0
        checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
        latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
        if latest_checkpoint:
            checkpoint.restore(latest_checkpoint)
            logging.info("Loaded checkpoint %s", latest_checkpoint)
            # Resume the step counter from the optimizer's iteration count.
            current_step = opt.iterations.numpy()

        if params["use_ctl"]:
            # Custom training loop: track the loss with a mean metric.
            train_loss_metric = tf.keras.metrics.Mean(
                "training_loss", dtype=tf.float32)
        else:
            # Keras path: configure the model for training with the optimizer.
            model.compile(opt)

    # Print the model structure.
    model.summary()

    if self.use_tpu:
        # Different from experimental_distribute_dataset,
        # experimental_distribute_datasets_from_function requires
        # per-replica/local batch size.
        params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
        train_ds = (
            self.distribution_strategy
            .experimental_distribute_datasets_from_function(
                lambda ctx: data_pipeline.train_input_fn(params, ctx)))
    else:
        # Parallel sentence pairs.
        train_ds = data_pipeline.train_input_fn(params)
        map_data_fn = data_pipeline.map_data_for_transformer_fn
        train_ds = train_ds.map(
            map_data_fn, num_parallel_calls=params["num_parallel_calls"])
    if params["use_ctl"]:
        train_ds_iterator = iter(train_ds)

    callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

    # TODO(b/139418525): Refactor the custom training loop logic.
    @tf.function
    def train_steps(iterator, steps):
        """Training steps function for TPU runs.

        Runs `steps` training steps; the loss of each step is recorded in
        `train_loss_metric`, which is reset before every step.

        Args:
            iterator: The input iterator of the training dataset.
            steps: An integer, the number of training steps.
        """

        def _step_fn(inputs):
            """Per-replica step function."""
            inputs, targets = inputs
            with tf.GradientTape() as tape:
                logits = model([inputs, targets], training=True)
                loss = metrics.transformer_loss(logits, targets,
                                                params["label_smoothing"],
                                                params["vocab_size"])
                # Scales the loss, which results in using the average loss
                # across all of the replicas for backprop.
                scaled_loss = (
                    loss / self.distribution_strategy.num_replicas_in_sync)

            # De-dupes variables due to keras tracking issues.
            tvars = list({id(v): v for v in model.trainable_variables}.values())
            grads = tape.gradient(scaled_loss, tvars)
            opt.apply_gradients(zip(grads, tvars))
            # For reporting, the metric takes the mean of losses.
            train_loss_metric.update_state(loss)

        for _ in tf.range(steps):
            train_loss_metric.reset_states()
            self.distribution_strategy.experimental_run_v2(
                _step_fn, args=(next(iterator),))

    cased_score, uncased_score = None, None
    cased_score_history, uncased_score_history = [], []
    # Bound up front so the stats below are well-defined even when the loop
    # body never runs (restored step already at or past train_steps).
    history = None
    train_loss = None
    while current_step < flags_obj.train_steps:
        remaining_steps = flags_obj.train_steps - current_step
        train_steps_per_eval = min(remaining_steps,
                                   flags_obj.steps_between_evals)
        current_iteration = current_step // flags_obj.steps_between_evals

        logging.info(
            "Start train iteration at global step:{}".format(current_step))
        history = None
        # The custom-loop branch (TPU only) uses train_steps defined above;
        # the other branch trains with model.fit().
        if params["use_ctl"]:
            if not self.use_tpu:
                raise NotImplementedError(
                    "Custom training loop on GPUs is not implemented.")
            # Runs training steps.
            train_steps(train_ds_iterator,
                        tf.convert_to_tensor(train_steps_per_eval,
                                             dtype=tf.int32))
            current_step += train_steps_per_eval
            train_loss = train_loss_metric.result().numpy().astype(float)
            logging.info("Train Step: %d/%d / loss = %s", current_step,
                         flags_obj.train_steps, train_loss)

            checkpoint_name = checkpoint.save(
                os.path.join(flags_obj.model_dir,
                             "ctl_step_{}.ckpt".format(current_step)))
            logging.info("Saved checkpoint to %s", checkpoint_name)
        else:
            if self.use_tpu:
                raise NotImplementedError(
                    "Keras model.fit on TPUs is not implemented.")
            history = model.fit(
                train_ds,
                initial_epoch=current_iteration,
                epochs=current_iteration + 1,
                steps_per_epoch=train_steps_per_eval,
                callbacks=callbacks,
                # If TimeHistory is enabled, progress bar would be messy.
                # Increase the verbose level to get rid of it.
                verbose=(2 if flags_obj.enable_time_history else 1))
            current_step += train_steps_per_eval
            logging.info("Train history: {}".format(history.history))

        logging.info(
            "End train iteration at global step:{}".format(current_step))

        if (flags_obj.bleu_source and flags_obj.bleu_ref):
            # self.eval() yields the uncased score first, then the cased one.
            uncased_score, cased_score = self.eval()
            cased_score_history.append([current_iteration + 1, cased_score])
            uncased_score_history.append(
                [current_iteration + 1, uncased_score])

    stats = ({
        "loss": train_loss
    } if history is None else misc.build_stats(history, callbacks))
    if uncased_score and cased_score:
        stats["bleu_uncased"] = uncased_score
        stats["bleu_cased"] = cased_score
        stats["bleu_uncased_history"] = uncased_score_history
        stats["bleu_cased_history"] = cased_score_history
    return stats
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If fp16 is passed as it is not currently supported.

    Returns:
        Dictionary of training and eval stats.
    """
    # Local import: only needed for the optional device scope below.
    import contextlib

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=(
                flags_obj.datasets_num_private_threads))
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # Compare against the TensorFlow dtype: the string 'fp16' is not a dtype
    # name, so comparing a dtype with it never matches and the check would be
    # dead code.
    if dtype == tf.float16:
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        synthetic_util.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        synthetic_util.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in
        # normalization layer, which triggers tf.where and leads to extra
        # memory copy of input sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=cifar_preprocessing.parse_record)

    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        initial_learning_rate = (common.BASE_LEARNING_RATE *
                                 flags_obj.batch_size / 128)
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
            values=[initial_learning_rate] +
            [p[0] * initial_learning_rate for p in LR_SCHEDULE])

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch)

    if not flags_obj.use_tensor_lr:
        # Without a tensor learning-rate schedule, the rate is driven by a
        # per-batch callback instead.
        lr_callback = LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            batch_size=flags_obj.batch_size,
            steps_per_epoch=steps_per_epoch)
        callbacks.append(lr_callback)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # The optional explicit device scope is entered through an ExitStack so
    # it is exited with the proper context-manager protocol and is released
    # even if training or evaluation raises.
    with contextlib.ExitStack() as device_stack:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_stack.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(
            train_input_dataset,
            epochs=train_epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
            validation_steps=num_eval_steps,
            validation_data=validation_data,
            validation_freq=flags_obj.epochs_between_evals,
            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(
                eval_input_dataset, steps=num_eval_steps, verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval with Keras and a KungFu optimizer.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If the dtype is float16 and `keras_utils.is_v2_0()` is
            false.

    Returns:
        Dictionary of training and eval stats.
    """
    mixed_precision = tf.compat.v2.keras.mixed_precision.experimental

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=(
                flags_obj.datasets_num_private_threads))
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        fp16_loss_scale = flags_core.get_loss_scale(flags_obj,
                                                    default_for_fp16=128)
        mixed_precision.set_policy(
            mixed_precision.Policy('mixed_float16',
                                   loss_scale=fp16_loss_scale))
        if not keras_utils.is_v2_0():
            raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    elif dtype == tf.bfloat16:
        mixed_precision.set_policy(mixed_precision.Policy('mixed_bfloat16'))

    data_format = flags_obj.data_format
    if data_format is None:
        if tf.test.is_built_with_cuda():
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(data_format)

    preprocessing_seed = 12345

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        # The seed, cluster size and rank are passed so the input function
        # can take the worker's place in the cluster into account.
        random_seed=preprocessing_seed,
        num_workers=current_cluster_size(),
        worker_ID=current_rank(),
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=[entry[1] for entry in common.LR_SCHEDULE[1:]],
            multipliers=[entry[0] for entry in common.LR_SCHEDULE],
            compute_lr_on_cpu=True)

    # Build KungFu optimizer: wrap the base optimizer and share its
    # hyper-parameter table with the wrapper.
    base_optimizer = common.get_optimizer(lr_schedule)
    optimizer = SynchronousSGDOptimizer(base_optimizer,
                                        reshape=False,
                                        use_locking=True)
    optimizer._hyper = base_optimizer._hyper

    if flags_obj.fp16_implementation == 'graph_rewrite':
        # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype
        # as determined by flags_core.get_tf_dtype(flags_obj) would be
        # 'float32' which will ensure tf.compat.v2.keras.mixed_precision and
        # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
        # double up.
        optimizer = (
            tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer))

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
        model = trivial_model.trivial_model(
            imagenet_preprocessing.NUM_CLASSES)
    else:
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
    # a valid arg for this model. Also remove as a valid flag.
    compile_kwargs = {
        'loss': 'sparse_categorical_crossentropy',
        'optimizer': optimizer,
        'metrics': [
            'sparse_categorical_accuracy',
            'sparse_top_k_categorical_accuracy',
        ],
        'run_eagerly': flags_obj.run_eagerly,
    }
    if flags_obj.force_v2_in_keras_compile is not None:
        compile_kwargs['experimental_run_tf_function'] = (
            flags_obj.force_v2_in_keras_compile)
    model.compile(**compile_kwargs)

    # Adjust number of steps: each epoch's steps are divided by the cluster
    # size.
    cluster_size = current_cluster_size()
    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size // cluster_size)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                     cluster_size,
                                     common.learning_rate_schedule)

    # Broadcast variables for KungFu
    callbacks.append(BroadcastGlobalVariablesCallback())

    # Checkpoint callback only on worker 0
    if flags_obj.enable_checkpoint_and_export and current_rank() == 0:
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(
                os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}'),
                save_weights_only=True))

    if flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None
    else:
        num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                          flags_obj.batch_size)
        validation_data = eval_input_dataset

    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    # Checkpoint only on 0th worker
    if flags_obj.enable_checkpoint_and_export and current_rank() == 0:
        if dtype == tf.bfloat16:
            logging.warning(
                "Keras model.save does not support bfloat16 dtype.")
        else:
            # Keras model.save assumes a float32 input designature.
            model.save(os.path.join(flags_obj.model_dir, 'saved_model'),
                       include_optimizer=False)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

    return common.build_stats(history, eval_output, callbacks)
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Configures the session, mixed-precision policy, image data format and
    distribution strategy from ``flags_obj``, builds the input datasets and
    the model, then calls ``model.fit`` and, unless ``skip_eval`` is set,
    ``model.evaluate``.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      Dictionary of training and eval stats.
    """
    # Function-scope import so this block does not depend on the file's
    # import section.
    import contextlib

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        common.data_delay_prefetch()
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(
        flags_obj.worker_hosts, flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # A constant learning rate unless the tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in LR_SCHEDULE],
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed
            # precision can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer,
                loss_scale=flags_core.get_loss_scale(flags_obj,
                                                     default_for_fp16=128))

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES, dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no
        # longer a valid arg for this model. Also remove as a valid flag.
        # The flag is only forwarded when it was explicitly set, so Keras
        # keeps its own default otherwise.
        compile_kwargs = {}
        if flags_obj.force_v2_in_keras_compile is not None:
            compile_kwargs['experimental_run_tf_function'] = (
                flags_obj.force_v2_in_keras_compile)
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly,
            **compile_kwargs)

    callbacks = common.get_callbacks(
        learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

    train_steps = (imagenet_preprocessing.NUM_IMAGES['train'] //
                   flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    if flags_obj.train_steps:
        # An explicit step budget caps the epoch length and forces one epoch.
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack guarantees the optional device scope is left through the
    # regular context-manager protocol, including when fit/evaluate raises.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(
            train_input_dataset,
            epochs=train_epochs,
            # NOTE(review): the epoch length is divided by a hard-coded 15;
            # the reason is not visible in this file -- confirm it is
            # intentional (e.g. a fixed worker count) before relying on it.
            steps_per_epoch=train_steps // 15,
            callbacks=callbacks,
            validation_steps=num_eval_steps,
            validation_data=validation_data,
            validation_freq=flags_obj.epochs_between_evals,
            verbose=1)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=1)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Builds a ``ResnetRunnable`` under the distribution strategy scope and
    drives it with an ``orbit.Controller``, optionally interleaving
    evaluation and periodic checkpointing.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config()
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # Queried once; the answer drives both the GPU tuning and the default
    # image data format below.
    has_gpu = bool(tf.config.list_physical_devices('GPU'))

    if has_gpu:
        if flags_obj.tf_gpu_thread_mode:
            keras_utils.set_gpu_thread_mode_and_count(
                per_gpu_thread_count=flags_obj.per_gpu_thread_count,
                gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
                num_gpus=flags_obj.num_gpus,
                datasets_num_private_threads=(
                    flags_obj.datasets_num_private_threads))
        common.set_cudnn_batchnorm_mode()

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = 'channels_first' if has_gpu else 'channels_last'
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    total_train_steps = per_epoch_steps * train_epochs

    # steps_per_loop defaults to one epoch and is never allowed to exceed it.
    if flags_obj.steps_per_loop is None:
        steps_per_loop = per_epoch_steps
    elif flags_obj.steps_per_loop > per_epoch_steps:
        steps_per_loop = per_epoch_steps
        logging.warning(
            'Setting steps_per_loop to %d to respect epoch boundary.',
            steps_per_loop)
    else:
        steps_per_loop = flags_obj.steps_per_loop

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        total_train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)

    with distribute_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    # Checkpoint every five loops when export is enabled; None is passed
    # through to the CheckpointManager otherwise.
    checkpoint_interval = (steps_per_loop * 5
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = (steps_per_loop
                        if flags_obj.enable_tensorboard else None)

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = orbit.Controller(
        strategy=strategy,
        trainer=runnable,
        evaluator=runnable if not flags_obj.skip_eval else None,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        summary_dir=flags_obj.model_dir,
        eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

    time_callback.on_train_begin()
    if not flags_obj.skip_eval:
        resnet_controller.train_and_evaluate(train_steps=total_train_steps,
                                             eval_steps=eval_steps,
                                             eval_interval=eval_interval)
    else:
        resnet_controller.train(steps=total_train_steps)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Builds a ResNet-50 model and an SGD optimizer under the distribution
    strategy scope, then runs a hand-written epoch/step loop that calls
    `train_step` for every batch and `test_step` every
    `flags_obj.epochs_between_evals` epochs (unless `skip_eval` is set).

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `--fp16_implementation=graph_rewrite` is requested
        without `--use_tf_function`.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    train_ds, test_ds = get_input_dataset(flags_obj, strategy)
    train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)

    time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                            flags_obj.log_steps)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)
    # NOTE(review): the source layout was flattened; everything from here to
    # the final `return` is assumed to live inside the strategy scope --
    # confirm against the upstream file.
    with strategy_scope:
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        optimizer = tf.keras.optimizers.SGD(
            learning_rate=common.BASE_LEARNING_RATE,
            momentum=0.9,
            nesterov=True)

        if flags_obj.fp16_implementation == "graph_rewrite":
            # The graph rewrite is only applied here when the step functions
            # are wrapped in tf.function, so reject the eager combination.
            if not flags_obj.use_tf_function:
                raise ValueError(
                    "--fp16_implementation=graph_rewrite requires "
                    "--use_tf_function to be true")
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer, loss_scale))

        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        trainable_variables = model.trainable_variables

        def train_step(train_ds_inputs):
            """Training StepFn.

            Runs one optimizer update on `train_ds_inputs` and returns the
            loss; under a strategy the per-replica losses are summed.
            """

            def step_fn(inputs):
                """Per-Replica StepFn."""
                images, labels = inputs
                with tf.GradientTape() as tape:
                    logits = model(images, training=True)

                    prediction_loss = (
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits))
                    # Sum the per-example losses and normalise by
                    # flags_obj.batch_size.
                    loss = tf.reduce_sum(prediction_loss) * (
                        1.0 / flags_obj.batch_size)
                    num_replicas = tf.distribute.get_strategy(
                    ).num_replicas_in_sync

                    if flags_obj.single_l2_loss_op:
                        # Single fused L2 term over all trainable variables
                        # whose name does not contain 'bn' (presumably this
                        # skips batch-norm parameters -- confirm).
                        filtered_variables = [
                            tf.reshape(v, (-1, ))
                            for v in trainable_variables
                            if 'bn' not in v.name
                        ]
                        l2_loss = (
                            resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
                                tf.concat(filtered_variables, axis=0)))
                        loss += (l2_loss / num_replicas)
                    else:
                        # Use the losses the model itself collected, split
                        # across replicas so the later SUM reduction does not
                        # count them more than once.
                        loss += (tf.reduce_sum(model.losses) / num_replicas)

                    # Scale the loss
                    if flags_obj.dtype == "fp16":
                        loss = optimizer.get_scaled_loss(loss)

                grads = tape.gradient(loss, trainable_variables)

                # Unscale the grads
                if flags_obj.dtype == "fp16":
                    grads = optimizer.get_unscaled_gradients(grads)

                optimizer.apply_gradients(zip(grads, trainable_variables))

                training_accuracy.update_state(labels, logits)
                return loss

            if strategy:
                per_replica_losses = strategy.experimental_run_v2(
                    step_fn, args=(train_ds_inputs, ))
                return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       per_replica_losses,
                                       axis=None)
            else:
                return step_fn(train_ds_inputs)

        def test_step(test_ds_inputs):
            """Evaluation StepFn.

            Updates the `test_loss` and `test_accuracy` metrics in place;
            returns nothing.
            """

            def step_fn(inputs):
                images, labels = inputs
                logits = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, logits)

            if strategy:
                strategy.experimental_run_v2(step_fn, args=(test_ds_inputs, ))
            else:
                step_fn(test_ds_inputs)

        # Optionally wrap both step functions in tf.function; the eager
        # Python versions are used otherwise.
        if flags_obj.use_tf_function:
            train_step = tf.function(train_step)
            test_step = tf.function(test_step)

        time_callback.on_train_begin()
        for epoch in range(train_epochs):

            # A fresh iterator and fresh accumulators for every epoch.
            train_iter = iter(train_ds)
            total_loss = 0.0
            training_accuracy.reset_states()

            for step in range(train_steps):
                # The learning rate is recomputed and assigned before each
                # step.
                optimizer.lr = common.learning_rate_schedule(
                    epoch, step, train_steps, flags_obj.batch_size)

                # The callback receives a global step index across epochs.
                time_callback.on_batch_begin(step + epoch * train_steps)
                total_loss += train_step(next(train_iter))
                time_callback.on_batch_end(step + epoch * train_steps)

            train_loss = total_loss / train_steps
            logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
                         train_loss.numpy(),
                         training_accuracy.result().numpy(), epoch)

            if (not flags_obj.skip_eval
                    and (epoch + 1) % flags_obj.epochs_between_evals == 0):
                test_loss.reset_states()
                test_accuracy.reset_states()

                test_iter = iter(test_ds)
                for _ in range(eval_steps):
                    test_step(next(test_iter))

                logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                             test_loss.result().numpy(),
                             test_accuracy.result().numpy(), epoch)

        time_callback.on_train_end()

        eval_result = None
        train_result = None
        # NOTE(review): as reconstructed, the training result is only
        # reported when evaluation is enabled, and `train_loss` is unbound if
        # `train_epochs` is 0 -- confirm both are intended.
        if not flags_obj.skip_eval:
            eval_result = [
                test_loss.result().numpy(),
                test_accuracy.result().numpy()
            ]
            train_result = [
                train_loss.numpy(),
                training_accuracy.result().numpy()
            ]

        stats = build_stats(train_result, eval_result, time_callback)
        return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Configures the session, mixed-precision policy, image data format and
    distribution strategy from ``flags_obj``, builds the input datasets and
    the model, then calls ``model.fit``, optionally exports the model, and,
    unless ``skip_eval`` is set, calls ``model.evaluate``.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If ``--dtype=fp16`` is requested while running under
        TensorFlow 1.

    Returns:
      Dictionary of training and eval stats.
    """
    # Function-scope import so this block does not depend on the file's
    # import section.
    import contextlib

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        loss_scale = flags_core.get_loss_scale(flags_obj,
                                               default_for_fp16=128)
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_float16', loss_scale=loss_scale)
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
        if not keras_utils.is_v2_0():
            raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    elif dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(
        flags_obj.worker_hosts, flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # A constant learning rate unless the tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in common.LR_SCHEDULE],
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite",
            # dtype as determined by flags_core.get_tf_dtype(flags_obj) would
            # be 'float32' which will ensure
            # tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do
            # not double up.
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer))

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no
        # longer a valid arg for this model. Also remove as a valid flag.
        # The flag is only forwarded when it was explicitly set, so Keras
        # keeps its own default otherwise.
        compile_kwargs = {}
        if flags_obj.force_v2_in_keras_compile is not None:
            compile_kwargs['experimental_run_tf_function'] = (
                flags_obj.force_v2_in_keras_compile)
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly,
            **compile_kwargs)

    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch,
                                     common.learning_rate_schedule)
    if flags_obj.enable_checkpoint_and_export:
        ckpt_full_path = os.path.join(flags_obj.model_dir,
                                      'model.ckpt-{epoch:04d}')
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                               save_weights_only=True))

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack guarantees the optional device scope is left through the
    # regular context-manager protocol, including when fit/save/evaluate
    # raises.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    "Keras model.save does not support bfloat16 dtype.")
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats