def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  When the dtype resolved from the flags compares equal to 'float16', a
  mixed precision policy is installed and the optimizer is wrapped in a
  LossScaleOptimizer; no dtype is rejected by this function itself.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # pylint: disable=g-import-not-at-top

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_delay_prefetch:
    keras_common.data_delay_prefetch()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=imagenet_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
  )

  # The eval dataset is only built when evaluation will actually run.
  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=dtype,
        drop_remainder=drop_remainder)

  # A plain float learning rate unless the tensor-based schedule is requested.
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=[p[1] for p in LR_SCHEDULE[1:]],
        multipliers=[p[0] for p in LR_SCHEDULE],
        compute_lr_on_cpu=True)

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer,
          loss_scale=flags_core.get_loss_scale(flags_obj,
                                               default_for_fp16=128))

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(
          imagenet_preprocessing.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly,
        experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

  train_steps = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

  # An explicit step budget caps the steps and collapses training to one epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # The optional device scope is entered through an ExitStack so that it is
  # always exited (and receives proper exception info) even when fit/evaluate
  # raises, instead of calling __enter__/__exit__ by hand.
  with contextlib.ExitStack() as device_stack:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training
      # loop when not using distribution strategy.
      device_stack.enter_context(tf.device('/device:GPU:0'))

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)

  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
def run(flags_obj):
  """Train (and optionally evaluate) ResNet on ImageNet with Keras fit/evaluate.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    session_config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=session_config)
    else:
      tf.keras.backend.set_session(tf.Session(config=session_config))
  else:
    keras_common.set_config_v2()

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)

  dtype = flags_core.get_tf_dtype(flags_obj)
  use_float16 = dtype == 'float16'
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  # Fall back to a build-dependent image layout when none was requested.
  image_format = flags_obj.data_format
  if image_format is None:
    if tf.test.is_built_with_cuda():
      image_format = 'channels_first'
    else:
      image_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster())
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  shared_input_args = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype,
      drop_remainder=flags_obj.enable_xla)

  train_ds = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_args)

  # Skip building the eval pipeline entirely when evaluation is disabled.
  eval_ds = None
  if not flags_obj.skip_eval:
    eval_ds = input_fn(is_training=False, **shared_input_args)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if use_float16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    # A static batch size is only given to the input layer for XLA in graph
    # mode, and only when at most one replica is in use.
    per_replica_batch_size = None
    if flags_obj.enable_xla and not flags_obj.enable_eager:
      # TODO(b/129861005): Fix OOM issue in eager mode when setting
      # `batch_size` in keras.Input layer.
      multi_replica = strategy and strategy.num_replicas_in_sync > 1
      if not multi_replica:
        # TODO(b/129791381): Specify `per_replica_batch_size` value in
        # DistributionStrategy multi-replica case.
        per_replica_batch_size = flags_obj.batch_size

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=per_replica_batch_size)

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=['sparse_categorical_accuracy'])

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  # An explicit step budget caps the steps and collapses training to one epoch.
  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1
  else:
    epochs = flags_obj.train_epochs

  eval_steps = imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size
  validation_ds = eval_ds
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    eval_steps = None
    validation_ds = None

  history = model.fit(
      train_ds,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=callbacks,
      validation_steps=eval_steps,
      validation_data=validation_ds,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_ds, steps=eval_steps, verbose=2)

  return keras_common.build_stats(history, eval_output, callbacks)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): get_tf_dtype presumably returns a TensorFlow dtype, which
  # would compare equal to 'float16' but never to the flag spelling 'fp16' --
  # confirm against flags_core. Both spellings are checked so that the
  # documented guard actually fires.
  if dtype == 'fp16' or dtype == 'float16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  per_device_batch_size = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  # Both datasets are batched with the per-device batch size computed above.
  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=per_device_batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=per_device_batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = (
      keras_common.get_callbacks(learning_rate_schedule,
                                 imagenet_main.NUM_IMAGES['train']))

  # Step counts are derived from the global batch size in the flags.
  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit step budget caps the steps and collapses training to one epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    num_eval_steps = None
    validation_data = None

  model.fit(train_input_dataset,
            epochs=train_epochs,
            steps_per_epoch=train_steps,
            callbacks=[
                time_callback,
                lr_callback,
                tensorboard_callback
            ],
            validation_steps=num_eval_steps,
            validation_data=validation_data,
            verbose=1)

  if not flags_obj.skip_eval:
    model.evaluate(eval_input_dataset,
                   steps=num_eval_steps,
                   verbose=1)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    The value produced by keras_common.build_stats for the fit history, the
    evaluation output (None when evaluation is skipped) and the time callback.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): get_tf_dtype presumably returns a TensorFlow dtype, which
  # would compare equal to 'float16' but never to the flag spelling 'fp16' --
  # confirm against flags_core. Both spellings are checked so that the
  # documented guard actually fires.
  if dtype == 'fp16' or dtype == 'float16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  # Fall back to a build-dependent image layout when none was requested.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=flags_obj.batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = (
      keras_common.get_callbacks(learning_rate_schedule,
                                 imagenet_main.NUM_IMAGES['train']))

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit step budget caps the steps and collapses training to one epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
def run(flags_obj):
  """Train (and optionally evaluate) ResNet on ImageNet with Keras fit/evaluate.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  session_config = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  running_v1 = not keras_common.is_v2_0()
  if running_v1 and flags_obj.enable_eager:
    tf.compat.v1.enable_eager_execution(config=session_config)
  elif running_v1:
    tf.keras.backend.set_session(tf.Session(config=session_config))
  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

  dtype = flags_core.get_tf_dtype(flags_obj)
  use_float16 = dtype == 'float16'
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  # Fall back to a build-dependent image layout when none was requested.
  image_format = flags_obj.data_format
  if image_format is None:
    if tf.test.is_built_with_cuda():
      image_format = 'channels_first'
    else:
      image_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_format)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)

  # Arguments common to the training and evaluation input pipelines.
  shared_input_args = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype)

  train_ds = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_args)
  eval_ds = input_fn(is_training=False, **shared_input_args)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  with keras_common.get_strategy_scope(strategy):
    optimizer = keras_common.get_optimizer()
    if use_float16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    model = resnet_model.resnet50(
        num_classes=imagenet_main.NUM_CLASSES, dtype=dtype)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = (
      keras_common.get_callbacks(
          learning_rate_schedule, imagenet_main.NUM_IMAGES['train']))
  fit_callbacks = [time_callback, lr_callback, tensorboard_callback]

  # An explicit step budget caps the steps and collapses training to one epoch.
  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1
  else:
    epochs = flags_obj.train_epochs

  eval_steps = imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size
  validation_ds = eval_ds
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    eval_steps = None
    validation_ds = None

  history = model.fit(
      train_ds,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=fit_callbacks,
      validation_steps=eval_steps,
      validation_data=validation_ds,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_ds, steps=eval_steps, verbose=2)

  return keras_common.build_stats(history, eval_output, time_callback)
def run(flags_obj):
  """Train (and optionally evaluate) ResNet on ImageNet with a custom loop.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  dtype = flags_core.get_tf_dtype(flags_obj)

  # TODO(anj-s): Set data_format without using Keras.
  image_format = flags_obj.data_format
  if image_format is None:
    if tf.test.is_built_with_cuda():
      image_format = 'channels_first'
    else:
      image_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  train_ds, test_ds = get_input_dataset(flags_obj, strategy)
  train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size, flags_obj.log_steps)

  with distribution_utils.get_strategy_scope(strategy):
    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        batch_size=flags_obj.batch_size)

    optimizer = tf.keras.optimizers.SGD(
        learning_rate=keras_common.BASE_LEARNING_RATE,
        momentum=0.9,
        nesterov=True)

    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    def train_step(train_ds_inputs):
      """Training StepFn."""

      def replica_train_fn(inputs):
        """Per-Replica StepFn."""
        batch_images, batch_labels = inputs
        with tf.GradientTape() as tape:
          logits = model(batch_images, training=True)
          per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
              batch_labels, logits)
          # Summed prediction loss scaled by the batch size from the flags.
          data_loss = tf.reduce_sum(per_example_loss) * (
              1.0 / flags_obj.batch_size)
          # Model-attached losses are divided by the replica count.
          replica_count = tf.distribute.get_strategy().num_replicas_in_sync
          model_loss = tf.reduce_sum(model.losses) / replica_count
          loss = data_loss + model_loss

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        training_accuracy.update_state(batch_labels, logits)
        return loss

      if not strategy:
        return replica_train_fn(train_ds_inputs)

      per_replica_losses = strategy.experimental_run_v2(
          replica_train_fn, args=(train_ds_inputs,))
      return strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    def test_step(test_ds_inputs):
      """Evaluation StepFn."""

      def replica_eval_fn(inputs):
        batch_images, batch_labels = inputs
        logits = model(batch_images, training=False)
        per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
            batch_labels, logits)
        scaled_loss = tf.reduce_sum(per_example_loss) * (
            1.0 / flags_obj.batch_size)
        test_loss.update_state(scaled_loss)
        test_accuracy.update_state(batch_labels, logits)

      if strategy:
        strategy.experimental_run_v2(replica_eval_fn, args=(test_ds_inputs,))
      else:
        replica_eval_fn(test_ds_inputs)

    if flags_obj.use_tf_function:
      train_step = tf.function(train_step)
      test_step = tf.function(test_step)

    time_callback.on_train_begin()
    for epoch in range(train_epochs):
      train_iter = iter(train_ds)
      total_loss = 0.0
      training_accuracy.reset_states()

      for step in range(train_steps):
        # Batch index counted across epochs, as reported to the time callback.
        global_step = step + epoch * train_steps
        optimizer.lr = keras_imagenet_main.learning_rate_schedule(
            epoch, step, train_steps, flags_obj.batch_size)

        time_callback.on_batch_begin(global_step)
        total_loss += train_step(next(train_iter))
        time_callback.on_batch_end(global_step)

      train_loss = total_loss / train_steps
      logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
                   train_loss.numpy(),
                   training_accuracy.result().numpy(),
                   epoch)

      eval_due = (epoch + 1) % flags_obj.epochs_between_evals == 0
      if not flags_obj.skip_eval and eval_due:
        test_loss.reset_states()
        test_accuracy.reset_states()

        test_iter = iter(test_ds)
        for _ in range(eval_steps):
          test_step(next(test_iter))

        logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                     test_loss.result().numpy(),
                     test_accuracy.result().numpy(),
                     epoch)

    time_callback.on_train_end()

    # NOTE(review): training results are only reported when evaluation is
    # enabled; with skip_eval both result lists stay None.
    if flags_obj.skip_eval:
      eval_result = None
      train_result = None
    else:
      eval_result = [test_loss.result().numpy(),
                     test_accuracy.result().numpy()]
      train_result = [train_loss.numpy(),
                      training_accuracy.result().numpy()]

    return build_stats(train_result, eval_result, time_callback)
def run(flags_obj):
  """Train (and optionally evaluate) ResNet on ImageNet with Keras fit/evaluate.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    session_config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=session_config)
    else:
      tf.keras.backend.set_session(tf.Session(config=session_config))
  else:
    keras_common.set_config_v2()

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_prefetch_with_slack:
    keras_common.data_prefetch_with_slack()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  use_float16 = dtype == 'float16'
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  # Fall back to a build-dependent image layout when none was requested.
  image_format = flags_obj.data_format
  if image_format is None:
    if tf.test.is_built_with_cuda():
      image_format = 'channels_first'
    else:
      image_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  shared_input_args = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype,
      drop_remainder=flags_obj.enable_xla)

  train_ds = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_args)

  # Skip building the eval pipeline entirely when evaluation is disabled.
  eval_ds = None
  if not flags_obj.skip_eval:
    eval_ds = input_fn(is_training=False, **shared_input_args)

  # A plain float learning rate unless the tensor-based schedule is requested.
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=[entry[1] for entry in LR_SCHEDULE[1:]],
        multipliers=[entry[0] for entry in LR_SCHEDULE],
        compute_lr_on_cpu=True)
  else:
    lr_schedule = 0.1

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if use_float16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    # A static batch size is only given to the input layer for XLA in graph
    # mode, and only when at most one replica is in use.
    input_layer_batch_size = None
    if flags_obj.enable_xla and not flags_obj.enable_eager:
      # TODO(b/129861005): Fix OOM issue in eager mode when setting
      # `batch_size` in keras.Input layer.
      multi_replica = strategy and strategy.num_replicas_in_sync > 1
      if not multi_replica:
        # TODO(b/129791381): Specify `input_layer_batch_size` value in
        # DistributionStrategy multi-replica case.
        input_layer_batch_size = flags_obj.batch_size

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)

    # Accuracy metrics are optional and controlled by a flag.
    if flags_obj.report_accuracy_metrics:
      compile_metrics = ['sparse_categorical_accuracy']
    else:
      compile_metrics = None

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=compile_metrics,
        cloning=flags_obj.clone_model_in_keras_dist_strat)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  # An explicit step budget caps the steps and collapses training to one epoch.
  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1
  else:
    epochs = flags_obj.train_epochs

  eval_steps = imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size
  validation_ds = eval_ds
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    eval_steps = None
    validation_ds = None

  history = model.fit(
      train_ds,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=callbacks,
      validation_steps=eval_steps,
      validation_data=validation_ds,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_ds, steps=eval_steps, verbose=2)

  return keras_common.build_stats(history, eval_output, callbacks)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    The value produced by keras_common.build_stats for the fit history, the
    evaluation output (None when evaluation is skipped) and the time callback.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): get_tf_dtype presumably returns a TensorFlow dtype, which
  # would compare equal to 'float16' but never to the flag spelling 'fp16' --
  # confirm against flags_core. Both spellings are checked so that the
  # documented guard actually fires.
  if dtype == 'fp16' or dtype == 'float16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  # Fall back to a build-dependent image layout when none was requested.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=flags_obj.batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = (
      keras_common.get_callbacks(learning_rate_schedule,
                                 imagenet_main.NUM_IMAGES['train']))

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit step budget caps the steps and collapses training to one epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats