import os

from absl import logging
import tensorflow as tf

# NOTE: the import paths below are assumptions based on the upstream
# TF official-models layout and the KungFu examples; adjust to this repo.
from kungfu import current_cluster_size, current_rank
from kungfu.tensorflow.initializer import BroadcastGlobalVariablesCallback
from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils, keras_utils
from official.vision.image_classification import (
    common, imagenet_preprocessing, resnet_model, trivial_model)


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is requested under TensorFlow 1, where it is not
      supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    if not keras_utils.is_v2_0():
      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Fixed seed shared by all workers so the input pipeline shards and
  # shuffles deterministically across the cluster.
  preprocessing_seed = 12345

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=imagenet_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      random_seed=preprocessing_seed,  # addition
      num_workers=current_cluster_size(),  # addition
      worker_ID=current_rank(),  # addition
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)

  # Build the KungFu optimizer: wrap the base optimizer so gradients are
  # synchronously averaged across all workers on every step.
  opt = common.get_optimizer(lr_schedule)
  optimizer = SynchronousSGDOptimizer(opt, reshape=False, use_locking=True)
  # Expose the base optimizer's hyperparameters (e.g. the learning rate) on
  # the wrapper so Keras callbacks can read and update them.
  optimizer._hyper = opt._hyper  # pylint: disable=protected-access

  if flags_obj.fp16_implementation == 'graph_rewrite':
    # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
    # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
    # which will ensure tf.compat.v2.keras.mixed_precision and
    # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
    # double up.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)

  # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
  if flags_obj.use_trivial_model:
    model = trivial_model.trivial_model(imagenet_preprocessing.NUM_CLASSES)
  else:
    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES)

  # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
  # a valid arg for this model. Also remove as a valid flag.
  metrics = ['sparse_categorical_accuracy',
             'sparse_top_k_categorical_accuracy']
  if flags_obj.force_v2_in_keras_compile is not None:
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=metrics,
        run_eagerly=flags_obj.run_eagerly,
        experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
  else:
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=metrics,
                  run_eagerly=flags_obj.run_eagerly)

  # Each worker processes 1/cluster_size of the data, so scale the number
  # of steps per epoch down by the cluster size.
  cluster_size = current_cluster_size()
  steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                     flags_obj.batch_size)
  steps_per_epoch = steps_per_epoch // cluster_size
  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                   cluster_size,
                                   common.learning_rate_schedule)
  # Broadcast worker 0's initial variables to all other KungFu workers.
  callbacks.append(BroadcastGlobalVariablesCallback())

  # Checkpoint callback only on worker 0.
  if flags_obj.enable_checkpoint_and_export and current_rank() == 0:
    ckpt_full_path = os.path.join(flags_obj.model_dir,
                                  'model.ckpt-{epoch:04d}')
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                           save_weights_only=True))

  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

  num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced
    # by control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  # Export the model only on worker 0.
  if flags_obj.enable_checkpoint_and_export and current_rank() == 0:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
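# The function above bundles the KungFu recipe with a lot of flag plumbing.
# Below is a minimal, self-contained sketch of just the data-parallel
# pattern it uses: wrap the base optimizer in SynchronousSGDOptimizer,
# broadcast worker 0's initial weights, and divide the per-epoch step count
# by the cluster size. The helper name and its arguments are illustrative,
# not part of this repo; import paths follow the KungFu examples.
def _kungfu_keras_fit_sketch(model, dataset, num_train_images, batch_size,
                             epochs=1):
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # Gradients are synchronously averaged across all workers on every step.
  optimizer = SynchronousSGDOptimizer(opt)
  model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer)
  # Each worker consumes 1/cluster_size of the data per epoch.
  steps_per_epoch = ((num_train_images // batch_size) //
                     current_cluster_size())
  # Start all replicas from identical weights (worker 0's values).
  callbacks = [BroadcastGlobalVariablesCallback()]
  return model.fit(dataset, epochs=epochs, steps_per_epoch=steps_per_epoch,
                   callbacks=callbacks, verbose=2)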
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  # get_tf_dtype returns a tf.DType, so compare against tf.float16 (the
  # original compared against the string 'fp16', which never matched).
  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value (fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable the
    # get_next_as_optional behavior in DistributedIterator. If true, the
    # last partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  # The original tf.data input pipeline is kept below for reference; this
  # KungFu port loads CIFAR-10 into memory via tf.keras.datasets instead.
  # train_input_dataset = input_fn(
  #     is_training=True,
  #     data_dir=flags_obj.data_dir,
  #     batch_size=flags_obj.batch_size,
  #     num_epochs=flags_obj.train_epochs,
  #     parse_record_fn=cifar_preprocessing.parse_record,
  #     datasets_num_private_threads=flags_obj.datasets_num_private_threads,
  #     dtype=dtype,
  #     # Setting drop_remainder to avoid the partial batch logic in
  #     # normalization layer, which triggers tf.where and leads to extra
  #     # memory copy of input sizes between host and GPU.
  #     drop_remainder=(not flags_obj.enable_get_next_as_optional))
  # eval_input_dataset = None
  # if not flags_obj.skip_eval:
  #   eval_input_dataset = input_fn(
  #       is_training=False,
  #       data_dir=flags_obj.data_dir,
  #       batch_size=flags_obj.batch_size,
  #       num_epochs=flags_obj.train_epochs,
  #       parse_record_fn=cifar_preprocessing.parse_record)

  num_classes = cifar_preprocessing.NUM_CLASSES  # was undefined originally
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255
  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes)

  # Build the KungFu optimizer: wrap plain SGD so gradients are
  # synchronously averaged across all workers on every step.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  optimizer = SynchronousSGDOptimizer(opt, use_locking=True)
  # Expose the base optimizer's hyperparameters (e.g. the learning rate) on
  # the wrapper so Keras callbacks can read and update them.
  optimizer._hyper = opt._hyper  # pylint: disable=protected-access

  model = Conv4_model(x_train, num_classes)

  # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
  # a valid arg for this model. Also remove as a valid flag.
  if flags_obj.force_v2_in_keras_compile is not None:
    model.compile(
        loss='categorical_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy'],
        run_eagerly=flags_obj.run_eagerly,
        experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
  else:
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'],
                  run_eagerly=flags_obj.run_eagerly)

  # Each worker processes 1/cluster_size of the data, so scale the number
  # of steps per epoch down by the cluster size.
  cluster_size = current_cluster_size()
  steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                     flags_obj.batch_size)
  steps_per_epoch = steps_per_epoch // cluster_size
  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                   cluster_size, learning_rate_schedule)
  # Broadcast worker 0's initial variables to all other KungFu workers.
  callbacks.append(BroadcastGlobalVariablesCallback())

  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = (x_test, y_test)
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  logging.info(x_train.shape)
  history = model.fit(x_train, y_train,
                      batch_size=flags_obj.batch_size,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    # model.evaluate takes x and y as separate arguments, not a tuple.
    eval_output = model.evaluate(x_test, y_test,
                                 steps=num_eval_steps,
                                 verbose=2)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
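# Both training loops are launched with KungFu's kungfu-run, which starts
# one worker process per rank; -np sets the cluster size reported by
# current_cluster_size(). The script name and flags below are illustrative
# for this repo, not a verified command line:
#
#   kungfu-run -np 4 python3 resnet_cifar_main.py \
#       --num_gpus=1 --batch_size=128 --train_epochs=10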