def main(_):
    """Export, fine-tune and/or run prediction for SQuAD, as chosen by FLAGS.mode.

    Behaviour per mode:
      * 'export_only': calls `export_squad` and returns immediately.
      * 'train': calls `train_squad`.
      * 'predict': calls `predict_squad`.
      * 'train_and_predict': calls `train_squad` followed by `predict_squad`.
    Any other mode only builds the distribution strategy.

    Args:
      _: Ignored.
    """
    # This script is only meant to be run under TF 2.x.
    installed_tf_version = tf.version.VERSION
    assert installed_tf_version.startswith('2.')

    # The meta data file is read as raw bytes and decoded as UTF-8 JSON.
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    if FLAGS.mode == 'export_only':
        # Exporting needs neither a cluster spec nor a distribution strategy.
        export_squad(FLAGS.model_export_path, input_meta_data)
        return

    # Configures cluster spec for multi-worker distribution strategy.
    if FLAGS.num_gpus > 0:
        distribution_utils.configure_cluster(FLAGS.worker_hosts,
                                             FLAGS.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        all_reduce_alg=FLAGS.all_reduce_alg,
        tpu_address=FLAGS.tpu)

    training_modes = ('train', 'train_and_predict')
    prediction_modes = ('predict', 'train_and_predict')

    if FLAGS.mode in training_modes:
        train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
    if FLAGS.mode in prediction_modes:
        predict_squad(strategy, input_meta_data)
def main(_):
    """Run the SQuAD phases (export / train / predict / eval) named in FLAGS.mode.

    'export_only' exports the model and returns. Otherwise a distribution
    strategy is built and each phase whose keyword ('train', 'predict',
    'eval') occurs as a substring of FLAGS.mode is executed, in that order.
    The eval phase logs the final F1 score, writes it as a TF summary scalar
    and dumps all eval metrics to `eval_metrics.json` under
    `<model_dir>/summaries/eval`.

    Args:
      _: Ignored.
    """
    # The meta data file is read as raw bytes and decoded as UTF-8 JSON.
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    if FLAGS.mode == 'export_only':
        export_squad(FLAGS.model_export_path, input_meta_data)
        return

    # Configures cluster spec for multi-worker distribution strategy.
    if FLAGS.num_gpus > 0:
        distribution_utils.configure_cluster(FLAGS.worker_hosts,
                                             FLAGS.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        all_reduce_alg=FLAGS.all_reduce_alg,
        tpu_address=FLAGS.tpu)

    # These are substring tests, so any mode string containing the keyword
    # triggers the corresponding phase.
    if 'train' in FLAGS.mode:
        train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
    if 'predict' in FLAGS.mode:
        predict_squad(strategy, input_meta_data)
    if 'eval' not in FLAGS.mode:
        return

    eval_metrics = eval_squad(strategy, input_meta_data)
    final_f1 = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', final_f1)

    eval_summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    writer = tf.summary.create_file_writer(eval_summary_dir)
    with writer.as_default():
        # TODO(lehou): write to the correct step number.
        tf.summary.scalar('F1-score', final_f1, step=0)
        writer.flush()

    # Also write eval_metrics to json file.
    metrics_json_path = os.path.join(eval_summary_dir, 'eval_metrics.json')
    squad_lib_sp.write_to_json_files(eval_metrics, metrics_json_path)

    # NOTE(review): fixed 60 second pause after the results are written; the
    # reason is not visible in this block -- confirm before changing it.
    time.sleep(60)
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `flags_obj.optimizer` is neither 'resnet50_default' nor
        'mobilenet_default', or if `flags_obj.model` is neither
        'resnet50_v1.5' nor 'mobilenet' while `use_trivial_model` is unset.
      NotImplementedError: If pruning is requested with a dtype other than
        float32, or with a pruning method other than 'polynomial_decay'.

    Returns:
      Dictionary of training and eval stats.
    """
    # Imported locally so this function does not rely on a module-level import.
    import contextlib

    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads)
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    # Default the image layout from the visible hardware when not given.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU')
                       else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional
        )

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # Current resnet_model.resnet50 input format is always channel-last.
    # We use the keras_application mobilenet model, whose input format depends
    # on the keras backend image data format.
    # This use_keras_image_data_format flag indicates whether the image
    # preprocessor output format should be the same as the keras backend image
    # data format or just channel-last format.
    use_keras_image_data_format = (flags_obj.model == 'mobilenet')
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    steps_per_epoch = (
        imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
        elif flags_obj.optimizer == 'mobilenet_default':
            initial_learning_rate = (
                flags_obj.initial_learning_rate_per_sample *
                flags_obj.batch_size)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
                    decay_rate=flags_obj.lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `optimizer` further down.
            raise ValueError(
                'Unsupported optimizer: %r. Expected "resnet50_default" or '
                '"mobilenet_default".' % (flags_obj.optimizer,))

        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype
            # as determined by flags_core.get_tf_dtype(flags_obj) would be
            # 'float32' which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
            # double up.
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer))

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=imagenet_preprocessing.NUM_CLASSES,
                layers=tf.keras.layers)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `model` further down.
            raise ValueError(
                'Unsupported model: %r. Expected "resnet50_v1.5" or '
                '"mobilenet".' % (flags_obj.model,))

        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')

        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (
        imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # The ExitStack guarantees the optional device scope is left even when
    # fit/evaluate/save raises, and calls __exit__ with the full
    # context-manager protocol arguments.
    with contextlib.ExitStack() as device_stack:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_stack.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

        if flags_obj.pruning_method:
            # Remove the pruning wrappers before the model is exported.
            model = tfmot.sparsity.keras.strip_pruning(model)
        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    'Keras model.save does not support bfloat16 dtype.')
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is requested while `keras_utils.is_v2_0()` is false.

    Returns:
      Dictionary of training and eval stats.
    """
    # Imported locally so this function does not rely on a module-level import.
    import contextlib

    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        # Reject the unsupported combination before touching the global
        # mixed-precision policy, so a failed call leaves no policy behind.
        if not keras_utils.is_v2_0():
            raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
        loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_float16', loss_scale=loss_scale)
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
    elif dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # Default the image layout from the TensorFlow build when not given.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional
        )

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # A constant learning rate unless the tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in common.LR_SCHEDULE),
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype
            # as determined by flags_core.get_tf_dtype(flags_obj) would be
            # 'float32' which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
            # double up.
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer))

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)

        compile_kwargs = {
            'loss': 'sparse_categorical_crossentropy',
            'optimizer': optimizer,
            'metrics': (['sparse_categorical_accuracy']
                        if flags_obj.report_accuracy_metrics else None),
            'run_eagerly': flags_obj.run_eagerly,
        }
        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        # The argument is only forwarded when the flag was explicitly set.
        if flags_obj.force_v2_in_keras_compile is not None:
            compile_kwargs['experimental_run_tf_function'] = (
                flags_obj.force_v2_in_keras_compile)
        model.compile(**compile_kwargs)

    steps_per_epoch = (
        imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch,
                                     common.learning_rate_schedule)
    if flags_obj.enable_checkpoint_and_export:
        ckpt_full_path = os.path.join(flags_obj.model_dir,
                                      'model.ckpt-{epoch:04d}')
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            ckpt_full_path, save_weights_only=True))

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (
        imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # The ExitStack guarantees the optional device scope is left even when
    # fit/save/evaluate raises, and calls __exit__ with the full
    # context-manager protocol arguments.
    with contextlib.ExitStack() as device_stack:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_stack.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    "Keras model.save does not support bfloat16 dtype.")
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def train_and_eval(
    params: base_configs.ExperimentConfig,
    strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
    """Runs the train and eval path using compile/fit.

    Args:
      params: Experiment configuration. This function reads its `runtime`,
        `model`, `train`, `evaluation` and `model_dir` fields.
      strategy_override: Distribution strategy to use. When falsy, a strategy
        is built from `params.runtime` instead.

    Returns:
      The value returned by `common.build_stats` for the fit history, the
      final evaluation output (None when evaluation is skipped) and the
      callbacks.
    """
    logging.info('Running train and eval.')

    distribution_utils.configure_cluster(
        params.runtime.worker_hosts,
        params.runtime.task_index)

    # Note: for TPUs, strategy and scope should be created before the dataset
    strategy = strategy_override or distribution_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    logging.info('Detected %d devices.',
                 strategy.num_replicas_in_sync if strategy else 1)

    # Labels are treated as one-hot only when a positive label smoothing value
    # is configured; this also selects the loss and the metrics below.
    label_smoothing = params.model.loss.label_smoothing
    one_hot = label_smoothing and label_smoothing > 0

    builders = _get_dataset_builders(params, strategy, one_hot)
    # A falsy builder yields a None dataset in the same position.
    datasets = [builder.build(strategy)
                if builder else None for builder in builders]

    # Unpack datasets and builders based on train/val/test splits
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    # Explicitly configured step counts win; otherwise the builders' own
    # `num_steps` are used.
    train_epochs = params.train.epochs
    train_steps = params.train.steps or train_builder.num_steps
    # NOTE(review): the comprehension above tolerates a falsy builder, but this
    # line dereferences `validation_builder` whenever
    # `params.evaluation.steps` is falsy -- confirm a validation builder
    # always exists.
    validation_steps = params.evaluation.steps or validation_builder.num_steps

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    with strategy_scope:
        # The model class is looked up by name and built from the config dict.
        model_params = params.model.model_params.as_dict()
        model = get_models()[params.model.name](**model_params)
        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_epochs=train_epochs,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict(),
            model=model)

        metrics_map = _get_metrics(one_hot)
        metrics = [metrics_map[metric] for metric in params.train.metrics]
        # Either a whole epoch of steps per execution, or a single step.
        steps_per_loop = train_steps if params.train.set_epoch_loop else 1

        if one_hot:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=params.model.loss.label_smoothing)
        else:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(optimizer=optimizer,
                      loss=loss_obj,
                      metrics=metrics,
                      experimental_steps_per_execution=steps_per_loop)

        # Training starts at epoch 0 unless a checkpoint is resumed.
        initial_epoch = 0
        if params.train.resume_checkpoint:
            initial_epoch = resume_from_checkpoint(model=model,
                                                   model_dir=params.model_dir,
                                                   train_steps=train_steps)

        callbacks = custom_callbacks.get_callbacks(
            model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
            include_tensorboard=params.train.callbacks.enable_tensorboard,
            time_history=params.train.callbacks.enable_time_history,
            track_lr=params.train.tensorboard.track_lr,
            write_model_weights=params.train.tensorboard.write_model_weights,
            initial_step=initial_epoch * train_steps,
            batch_size=train_builder.global_batch_size,
            log_steps=params.train.time_history.log_steps,
            model_dir=params.model_dir)

    serialize_config(params=params, model_dir=params.model_dir)

    # Validation arguments are only handed to fit() when evaluation is enabled.
    if params.evaluation.skip_eval:
        validation_kwargs = {}
    else:
        validation_kwargs = {
            'validation_data': validation_dataset,
            'validation_steps': validation_steps,
            'validation_freq': params.evaluation.epochs_between_evals,
        }

    history = model.fit(
        train_dataset,
        epochs=train_epochs,
        steps_per_epoch=train_steps,
        initial_epoch=initial_epoch,
        callbacks=callbacks,
        verbose=2,
        **validation_kwargs)

    # One more evaluation pass after training, unless evaluation is skipped.
    validation_output = None
    if not params.evaluation.skip_eval:
        validation_output = model.evaluate(
            validation_dataset, steps=validation_steps, verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    stats = common.build_stats(history,
                               validation_output,
                               callbacks)
    return stats
def run_predict(flags_obj, datasets_override=None, strategy_override=None):
    """Run feature prediction on the eval dataset and dump the outputs as .npy.

    Builds the model in 'resnet50_features' mode, restores the latest
    checkpoint found under `GB_OPTIONS.pretrained_filepath` (falling back to
    `GB_OPTIONS.checkpoint_folder` when that is None), runs `model.predict`
    over the eval dataset and saves the first three prediction outputs with
    `np.save`.

    Args:
      flags_obj: An object containing parsed flag values.
      datasets_override: Currently unused; kept for interface compatibility.
      strategy_override: Currently unused; kept for interface compatibility.

    Raises:
      ValueError: If no checkpoint can be found under the load path.

    Returns:
      The string 'good'.
    """
    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=1,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads)
    common.set_cudnn_batchnorm_mode()

    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    # Default the image layout from the TensorFlow build when not given.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    # Prediction is pinned to a single GPU regardless of flags_obj.num_gpus.
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=1,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional
        )

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    distribution_utils.undo_set_up_synthetic_data()

    # Only the eval input dataset (second of the four returned values) is
    # needed for prediction.
    _, pred_input_dataset, _, _ = setup_datasets(
        flags_obj, shuffle=False, save_labels=True)

    with strategy_scope:
        model = build_model(imagenet_preprocessing.NUM_CLASSES,
                            mode='resnet50_features',
                            save_labels=True)

        load_path = GB_OPTIONS.pretrained_filepath
        if load_path is None:
            load_path = GB_OPTIONS.checkpoint_folder
        latest = tf.train.latest_checkpoint(load_path)
        print(latest)
        # latest_checkpoint returns None when nothing is found; fail with a
        # clear message instead of an obscure error inside load_weights.
        if latest is None:
            raise ValueError('No checkpoint found under %r.' % (load_path,))
        model.load_weights(latest)

    num_eval_steps = (
        imagenet_preprocessing.NUM_IMAGES['validation'] // GB_OPTIONS.batch_size)

    pred = model.predict(
        pred_input_dataset,
        batch_size=GB_OPTIONS.batch_size,
        steps=num_eval_steps)

    # NOTE(review): the output names are appended to out_npys_folder by plain
    # string concatenation, so the value acts as a path prefix and needs a
    # trailing separator to name a directory -- confirm the intended usage
    # before switching to os.path.join.
    np.save(GB_OPTIONS.out_npys_folder + 'out_X', pred[0])
    np.save(GB_OPTIONS.out_npys_folder + 'out_labels', pred[1])
    np.save(GB_OPTIONS.out_npys_folder + 'out_ori_labels', pred[2])
    return 'good'
def run(flags_obj):
    """Run CIFAR-10 training and eval loop using native Keras APIs.

    The data is loaded in memory with `tf.keras.datasets.cifar10.load_data`,
    scaled to [0, 1] and one-hot encoded; the model is `Conv4_model` trained
    with a plain SGD optimizer wrapped in `SynchronousSGDOptimizer`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): this compares against the string 'fp16'; if get_tf_dtype
    # returns a tf.DType (as its name suggests) the check can never fire --
    # confirm, and compare against tf.float16 if the guard is meant to be live.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    # Default the image layout from the TensorFlow build when not given.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    # NOTE(review): the scope is created but never entered below; the model is
    # built outside of it. Kept as-is -- confirm whether that is intended.
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # NOTE(review): `input_fn` is currently unused because training reads the
    # in-memory arrays loaded below. The surrounding calls are kept because
    # they may toggle global synthetic-data state.
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    # Load CIFAR-10 in memory, scale pixels to [0, 1] and one-hot the labels.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)

    opt = tf.keras.optimizers.SGD(learning_rate=0.1)
    logging.info(opt.__dict__)
    optimizer = SynchronousSGDOptimizer(opt, use_locking=True)
    # The wrapper shares the hyper-parameter dict of the wrapped optimizer.
    optimizer._hyper = opt._hyper  # pylint: disable=protected-access
    logging.info(optimizer.__dict__)

    model = Conv4_model(x_train, num_classes)

    compile_kwargs = {
        'loss': 'categorical_crossentropy',
        'optimizer': optimizer,
        'metrics': ['accuracy'],
        'run_eagerly': flags_obj.run_eagerly,
    }
    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    # The argument is only forwarded when the flag was explicitly set.
    if flags_obj.force_v2_in_keras_compile is not None:
        compile_kwargs['experimental_run_tf_function'] = (
            flags_obj.force_v2_in_keras_compile)
    model.compile(**compile_kwargs)

    # Each worker runs its share of the steps of one epoch.
    cluster_size = current_cluster_size()
    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    steps_per_epoch = steps_per_epoch // cluster_size
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                     cluster_size, learning_rate_schedule)
    callbacks.append(BroadcastGlobalVariablesCallback())

    if flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = (x_test, y_test)
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        # With skip_eval no validation data is handed to fit() at all, so the
        # flag really disables per-epoch validation.
        num_eval_steps = None
        validation_data = None

    tf.compat.v1.logging.info(x_train.shape)
    history = model.fit(x_train,
                        y_train,
                        batch_size=flags_obj.batch_size,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        # Inputs and labels are passed separately. The batch size matches the
        # one used to derive num_eval_steps, mirroring the fit() call above.
        eval_output = model.evaluate(x_test,
                                     y_test,
                                     batch_size=flags_obj.batch_size,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `--fp16_implementation=graph_rewrite` is requested while
      `--use_tf_function` is false.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  train_ds, test_ds = get_input_dataset(flags_obj, strategy)
  train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)

  time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                          flags_obj.log_steps)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)
  with strategy_scope:
    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        batch_size=flags_obj.batch_size,
        use_l2_regularizer=not flags_obj.single_l2_loss_op)

    optimizer = tf.keras.optimizers.SGD(
        learning_rate=common.BASE_LEARNING_RATE,
        momentum=0.9,
        nesterov=True)

    if flags_obj.fp16_implementation == "graph_rewrite":
      if not flags_obj.use_tf_function:
        raise ValueError(
            "--fp16_implementation=graph_rewrite requires "
            "--use_tf_function to be true")
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer, loss_scale)

    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    trainable_variables = model.trainable_variables

    def train_step(train_ds_inputs):
      """Training StepFn."""

      def step_fn(inputs):
        """Per-Replica StepFn."""
        images, labels = inputs
        with tf.GradientTape() as tape:
          logits = model(images, training=True)

          prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, logits)
          loss = tf.reduce_sum(prediction_loss) * (
              1.0 / flags_obj.batch_size)
          num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

          if flags_obj.single_l2_loss_op:
            filtered_variables = [
                tf.reshape(v, (-1,))
                for v in trainable_variables
                if 'bn' not in v.name
            ]
            l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
                tf.concat(filtered_variables, axis=0))
            loss += (l2_loss / num_replicas)
          else:
            loss += (tf.reduce_sum(model.losses) / num_replicas)

          # Scale the loss for the backward pass only. `loss` itself stays
          # unscaled so the value returned to the caller (and logged as the
          # training loss) is not multiplied by the loss scale.
          if flags_obj.dtype == "fp16":
            scaled_loss = optimizer.get_scaled_loss(loss)
          else:
            scaled_loss = loss

        grads = tape.gradient(scaled_loss, trainable_variables)

        # Unscale the grads
        if flags_obj.dtype == "fp16":
          grads = optimizer.get_unscaled_gradients(grads)

        optimizer.apply_gradients(zip(grads, trainable_variables))

        training_accuracy.update_state(labels, logits)
        return loss

      if strategy:
        per_replica_losses = strategy.experimental_run_v2(
            step_fn, args=(train_ds_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_losses,
                               axis=None)
      else:
        return step_fn(train_ds_inputs)

    def test_step(test_ds_inputs):
      """Evaluation StepFn."""

      def step_fn(inputs):
        images, labels = inputs
        logits = model(images, training=False)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
        test_loss.update_state(loss)
        test_accuracy.update_state(labels, logits)

      if strategy:
        strategy.experimental_run_v2(step_fn, args=(test_ds_inputs,))
      else:
        step_fn(test_ds_inputs)

    if flags_obj.use_tf_function:
      train_step = tf.function(train_step)
      test_step = tf.function(test_step)

    time_callback.on_train_begin()
    for epoch in range(train_epochs):
      # A fresh iterator is created for every epoch.
      train_iter = iter(train_ds)
      total_loss = 0.0
      training_accuracy.reset_states()

      for step in range(train_steps):
        optimizer.lr = common.learning_rate_schedule(
            epoch, step, train_steps, flags_obj.batch_size)

        time_callback.on_batch_begin(step + epoch * train_steps)
        total_loss += train_step(next(train_iter))
        time_callback.on_batch_end(step + epoch * train_steps)

      train_loss = total_loss / train_steps
      logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
                   train_loss.numpy(),
                   training_accuracy.result().numpy(),
                   epoch)

      if (not flags_obj.skip_eval and
          (epoch + 1) % flags_obj.epochs_between_evals == 0):
        test_loss.reset_states()
        test_accuracy.reset_states()

        test_iter = iter(test_ds)
        for _ in range(eval_steps):
          test_step(next(test_iter))

        logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                     test_loss.result().numpy(),
                     test_accuracy.result().numpy(),
                     epoch)

    time_callback.on_train_end()

    eval_result = None
    train_result = None
    if not flags_obj.skip_eval:
      eval_result = [
          test_loss.result().numpy(),
          test_accuracy.result().numpy()
      ]
      train_result = [
          train_loss.numpy(),
          training_accuracy.result().numpy()
      ]

    stats = build_stats(train_result, eval_result, time_callback)
    return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `--fp16_implementation=graph_rewrite` is requested while
      `--use_tf_function` is false.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  train_ds, test_ds = get_input_dataset(flags_obj, strategy)
  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  # Never run more steps per host loop than one epoch contains.
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
  logging.info(
      "Training %d epochs, each epoch has %d steps, "
      "total steps: %d; Eval %d steps",
      train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
      eval_steps)

  time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                          flags_obj.log_steps)

  with distribution_utils.get_strategy_scope(strategy):
    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        batch_size=flags_obj.batch_size,
        use_l2_regularizer=not flags_obj.single_l2_loss_op)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    optimizer = common.get_optimizer(lr_schedule)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      if not flags_obj.use_tf_function:
        raise ValueError(
            '--fp16_implementation=graph_rewrite requires '
            '--use_tf_function to be true')
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer, loss_scale)

    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    trainable_variables = model.trainable_variables

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(images, training=True)

        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        loss = tf.reduce_sum(prediction_loss) * (1.0 / flags_obj.batch_size)
        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

        if flags_obj.single_l2_loss_op:
          filtered_variables = [
              tf.reshape(v, (-1,))
              for v in trainable_variables
              if 'bn' not in v.name
          ]
          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
              tf.concat(filtered_variables, axis=0))
          loss += (l2_loss / num_replicas)
        else:
          loss += (tf.reduce_sum(model.losses) / num_replicas)

        # Scale the loss for the backward pass only. `loss` itself stays
        # unscaled so the `train_loss` metric is not multiplied by the loss
        # scale.
        if flags_obj.dtype == "fp16":
          scaled_loss = optimizer.get_scaled_loss(loss)
        else:
          scaled_loss = loss

      grads = tape.gradient(scaled_loss, trainable_variables)

      # Unscale the grads
      if flags_obj.dtype == "fp16":
        grads = optimizer.get_unscaled_gradients(grads)

      optimizer.apply_gradients(zip(grads, trainable_variables))
      train_loss.update_state(loss)
      training_accuracy.update_state(labels, logits)

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop."""
      for _ in tf.range(steps):
        # `strategy` may be falsy (see the guards in the other step
        # helpers); fall back to calling the step function directly.
        if strategy:
          strategy.experimental_run_v2(step_fn, args=(next(iterator),))
        else:
          step_fn(next(iterator))

    def train_single_step(iterator):
      """Performs exactly one training step."""
      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        return step_fn(next(iterator))

    def test_step(iterator):
      """Evaluation StepFn."""

      def step_fn(inputs):
        images, labels = inputs
        logits = model(images, training=False)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
        test_loss.update_state(loss)
        test_accuracy.update_state(labels, logits)

      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        step_fn(next(iterator))

    if flags_obj.use_tf_function:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    # One iterator is shared across all epochs.
    train_iter = iter(train_ds)
    time_callback.on_train_begin()
    for epoch in range(train_epochs):
      train_loss.reset_states()
      training_accuracy.reset_states()

      steps_in_current_epoch = 0
      while steps_in_current_epoch < per_epoch_steps:
        time_callback.on_batch_begin(steps_in_current_epoch +
                                     epoch * per_epoch_steps)
        steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                              steps_per_loop)
        if steps == 1:
          train_single_step(train_iter)
        else:
          # Converts steps to a Tensor to avoid tf.function retracing.
          train_steps(train_iter,
                      tf.convert_to_tensor(steps, dtype=tf.int32))
        time_callback.on_batch_end(steps_in_current_epoch +
                                   epoch * per_epoch_steps)
        steps_in_current_epoch += steps

      logging.info('Training loss: %s, accuracy: %s at epoch %d',
                   train_loss.result().numpy(),
                   training_accuracy.result().numpy(),
                   epoch + 1)

      if (not flags_obj.skip_eval and
          (epoch + 1) % flags_obj.epochs_between_evals == 0):
        test_loss.reset_states()
        test_accuracy.reset_states()

        test_iter = iter(test_ds)
        for _ in range(eval_steps):
          test_step(test_iter)

        logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                     test_loss.result().numpy(),
                     test_accuracy.result().numpy(),
                     epoch + 1)

    time_callback.on_train_end()

    eval_result = None
    train_result = None
    if not flags_obj.skip_eval:
      eval_result = [
          test_loss.result().numpy(),
          test_accuracy.result().numpy()
      ]
      train_result = [
          train_loss.result().numpy(),
          training_accuracy.result().numpy()
      ]

    stats = build_stats(train_result, eval_result, time_callback)
    return stats
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run. Contains the keys `eval_results` and
    `train_hooks`. `eval_results` contains accuracy (top_1) and
    accuracy_top_5. `train_hooks` is a list the instances of hooks used during
    training. An empty dict is returned when `tf.estimator.train_and_evaluate`
    is used.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  sess_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  dist_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Creates a `RunConfig` that checkpoints every 24 hours which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  estimator_config = tf.estimator.RunConfig(
      train_distribute=dist_strategy,
      session_config=sess_config,
      save_checkpoints_secs=60*60*24,
      save_checkpoints_steps=None)

  # Initializes model with all but the dense layer from pretrained ResNet.
  warm_start = None
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')

  estimator_params = {
      'resnet_size': int(flags_obj.resnet_size),
      'data_format': flags_obj.data_format,
      'batch_size': flags_obj.batch_size,
      'resnet_version': int(flags_obj.resnet_version),
      'loss_scale': flags_core.get_loss_scale(flags_obj),
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'fine_tune': flags_obj.fine_tune,
      'num_workers': num_workers,
  }
  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=estimator_config,
      warm_start_from=warm_start,
      params=estimator_params)

  # Run metadata recorded by the benchmark logger.
  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def _replica_batch_size():
    """Returns the per-replica share of the global batch size."""
    return distribution_utils.per_replica_batch_size(
        flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

  def input_fn_train(num_epochs, input_context=None):
    """Builds the training input pipeline for `num_epochs` epochs."""
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=_replica_batch_size(),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches,
        input_context=input_context)

  def input_fn_eval():
    """Builds the single-epoch evaluation input pipeline."""
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=_replica_batch_size(),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    train_epochs = 0
  else:
    train_epochs = flags_obj.train_epochs

  if flags_obj.use_train_and_evaluate or num_workers > 1:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=flags_obj.max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    tf.compat.v1.logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evalute doesn't return anything in multi-worker
    # case.
    return {}

  if train_epochs == 0:
    # If --eval_only is set, perform a single loop with zero train epochs.
    n_loops = 1
    schedule = [0]
  else:
    # Compute the number of times to loop while training. All but the last
    # pass will train for `epochs_between_evals` epochs, while the last will
    # train for the number needed to reach `training_epochs`. For instance if
    #   train_epochs = 25 and epochs_between_evals = 10
    # schedule will be set to [10, 10, 5]. That is to say, the loop will:
    #   Train for 10 epochs and then evaluate.
    #   Train for another 10 epochs and then evaluate.
    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
    n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals] * int(n_loops)
    schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                              int(n_loops))

    if num_train_epochs:
      # Since we are calling classifier.train immediately in each loop, the
      # value of num_train_epochs in the lambda function will not be changed
      # before it is used. So it is safe to ignore the pylint error here
      # pylint: disable=cell-var-from-loop
      classifier.train(
          input_fn=lambda input_context=None: input_fn_train(
              num_train_epochs, input_context=input_context),
          hooks=train_hooks,
          max_steps=flags_obj.max_train_steps)

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data,
    # which will iterate forever. Passing steps=flags_obj.max_train_steps
    # allows the eval (which is generally unimportant in those circumstances)
    # to terminate.  Note that eval will run for max_train_steps each loop,
    # regardless of the global_step count.
    tf.compat.v1.logging.info('Starting to evaluate.')
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

  return {
      'eval_results': eval_results,
      'train_hooks': train_hooks,
  }
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  The optimizer is wrapped in Horovod's `DistributedOptimizer` and the
  training callbacks are the Horovod broadcast / metric-average / warmup
  callbacks, plus a checkpoint callback on rank 0.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    common.set_gpu_thread_mode_and_count(flags_obj)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # `dtype` is a tf.DType, so it must be compared against a dtype. Comparing
  # it with the flag string 'fp16' never matches and silently disables the
  # check.
  if dtype == tf.float16:
    raise ValueError(
        'dtype fp16 is not supported in Keras. Use the default '
        'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in normalization
      # layer, which triggers tf.where and leads to extra memory copy of input
      # sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record)

  with strategy_scope:
    # The base learning rate is scaled by the Horovod world size.
    optimizer = common.get_optimizer(learning_rate=0.1 * hvd.size())
    # Horovod: add Horovod DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(optimizer)
    model = resnet_cifar_model.resnet56(
        classes=cifar_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
    # a valid arg for this model. Also remove as a valid flag.
    # The compile arguments were identical whether or not
    # `force_v2_in_keras_compile` was set, so a single call is used.
    model.compile(
        loss='categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        #run_eagerly=flags_obj.run_eagerly,
        experimental_run_tf_function=False)

  # NOTE(review): this list is replaced by the Horovod callbacks below, so
  # none of these callbacks reach `model.fit` or `build_stats`. The call is
  # kept in case `get_callbacks` has side effects -- confirm and remove.
  callbacks = common.get_callbacks(learning_rate_schedule,
                                   cifar_preprocessing.NUM_IMAGES['train'])

  train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribition strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  callbacks = [
      # Horovod: broadcast initial variable states from rank 0 to all other
      # processes. This is necessary to ensure consistent initialization of
      # all workers when training is started with random weights or restored
      # from a checkpoint.
      hvd.callbacks.BroadcastGlobalVariablesCallback(0),

      # Horovod: average metrics among workers at the end of every epoch.
      #
      # Note: This callback must be in the list before the ReduceLROnPlateau,
      # TensorBoard or other metrics-based callbacks.
      hvd.callbacks.MetricAverageCallback(),

      # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads
      # to worse final accuracy. Scale the learning rate
      # `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during the first three epochs.
      # See https://arxiv.org/abs/1706.02677 for details.
      hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
  ]

  # Horovod: save checkpoints only on worker 0 to prevent other workers from
  # corrupting them.
  if hvd.rank() == 0:
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

  # Horovod: write logs on worker 0.
  verbose = 1 if hvd.rank() == 0 else 0

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=verbose)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Configures the session, mixed precision and the distribution strategy from
  the flags, builds the input datasets, compiles the model, fits it and
  optionally evaluates it.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      tf.keras.backend.set_session(tf.Session(config=config))

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_prefetch_with_slack:
    keras_common.data_prefetch_with_slack()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  data_format = flags_obj.data_format
  if data_format is None:
    if tf.test.is_built_with_cuda():
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Keyword arguments shared by the training and evaluation pipelines.
  shared_input_kwargs = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype,
      drop_remainder=drop_remainder)

  train_input_dataset = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_kwargs)

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(is_training=False, **shared_input_kwargs)

  # A constant learning rate unless a tensor-based schedule is requested.
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in LR_SCHEDULE),
        compute_lr_on_cpu=True)
  else:
    lr_schedule = 0.1

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    # The input layer only gets a static batch size for XLA in graph mode
    # with at most one replica.
    input_layer_batch_size = None
    if flags_obj.enable_xla and not flags_obj.enable_eager:
      # TODO(b/129861005): Fix OOM issue in eager mode when setting
      # `batch_size` in keras.Input layer.
      # TODO(b/129791381): Specify `input_layer_batch_size` value in
      # DistributionStrategy multi-replica case.
      if not (strategy and strategy.num_replicas_in_sync > 1):
        input_layer_batch_size = flags_obj.batch_size

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)

    if flags_obj.report_accuracy_metrics:
      compile_metrics = ['sparse_categorical_accuracy']
    else:
      compile_metrics = None
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=compile_metrics,
                  cloning=flags_obj.clone_model_in_keras_dist_strat)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  epochs = flags_obj.train_epochs
  if flags_obj.train_steps:
    # An explicit step budget caps the run at a single, possibly shorter,
    # epoch.
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)
  validation_data = eval_input_dataset

  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  return keras_common.build_stats(history, eval_output, callbacks)
def __init__(self, flags_obj):
    """Set up model params, distribution strategy and precision policy.

    Args:
        flags_obj: Object containing parsed flag values, i.e., FLAGS.

    Raises:
        ValueError: if not using static batch for input data on TPU.
    """
    self.flags_obj = flags_obj
    self.predict_model = None

    # Start from the named parameter set and overlay flag-defined values.
    gpu_count = flags_core.get_num_gpus(flags_obj)
    params = misc.get_model_params(flags_obj.param_set, gpu_count)
    self.params = params

    params["num_gpus"] = gpu_count
    # Flags copied through unchanged under the same key name.
    for flag_name in ("use_ctl", "data_dir", "model_dir", "static_batch",
                      "max_length", "decode_batch_size", "decode_max_length",
                      "padded_decode"):
        params[flag_name] = getattr(flags_obj, flag_name)

    # A falsy flag value falls back to letting tf.data pick the parallelism.
    parallel_calls = flags_obj.num_parallel_calls
    if not parallel_calls:
        parallel_calls = tf.data.experimental.AUTOTUNE
    params["num_parallel_calls"] = parallel_calls
    params["use_synthetic_data"] = flags_obj.use_synthetic_data

    # A falsy batch-size flag falls back to the parameter set's default.
    batch_size = flags_obj.batch_size
    if not batch_size:
        batch_size = params["default_batch_size"]
    params["batch_size"] = batch_size
    params["repeat_dataset"] = None
    params["dtype"] = flags_core.get_tf_dtype(flags_obj)
    for flag_name in ("enable_tensorboard", "enable_metrics_in_training",
                      "steps_between_evals"):
        params[flag_name] = getattr(flags_obj, flag_name)

    worker_count = distribution_utils.configure_cluster(
        flags_obj.worker_hosts, flags_obj.task_index)
    self.distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=gpu_count,
        num_workers=worker_count,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu or "")

    if not self.use_tpu:
        logging.info("Running transformer with num_gpus = %d", gpu_count)
    else:
        params["num_replicas"] = (
            self.distribution_strategy.num_replicas_in_sync)
        if not params["static_batch"]:
            raise ValueError("TPU requires static batch for input data.")

    if self.distribution_strategy:
        logging.info("For training, using distribution strategy: %s",
                     self.distribution_strategy)
    else:
        logging.info("Not using any distribution strategy.")

    model_dtype = params["dtype"]
    if model_dtype == tf.float16:
        # TODO(reedwm): It's pretty ugly to set the global policy in a
        # constructor like this. What if multiple instances of
        # TransformerTask are created? We should have a better way in the
        # tf.keras.mixed_precision API of doing this.
        fp16_loss_scale = flags_core.get_loss_scale(
            flags_obj, default_for_fp16="dynamic")
        mixed_precision = tf.compat.v2.keras.mixed_precision.experimental
        mixed_precision.set_policy(
            mixed_precision.Policy("mixed_float16",
                                   loss_scale=fp16_loss_scale))
    elif model_dtype == tf.bfloat16:
        mixed_precision = tf.compat.v2.keras.mixed_precision.experimental
        mixed_precision.set_policy(mixed_precision.Policy("mixed_bfloat16"))
def run(flags_obj):
    """Train and evaluate ResNet-50 on ImageNet with the native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        The stats assembled by `keras_common.build_stats` from the fit
        history, the final evaluation output and the timing callback.
    """
    session_config = keras_common.get_config_proto()
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if not keras_common.is_v2_0():
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=session_config)
        else:
            tf.keras.backend.set_session(tf.Session(config=session_config))
    # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is
    # ready.

    dtype = flags_core.get_tf_dtype(flags_obj)
    use_fp16 = dtype == 'float16'
    if use_fp16:
        tf.keras.mixed_precision.experimental.set_policy(
            tf.keras.mixed_precision.experimental.Policy(
                'infer_float32_vars'))

    # Without an explicit flag, choose the layout from the CUDA build status.
    image_format = flags_obj.data_format
    if image_format is None:
        if tf.test.is_built_with_cuda():
            image_format = 'channels_first'
        else:
            image_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_format)

    # pylint: disable=protected-access
    if not flags_obj.use_synthetic_data:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn
    else:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype)

    # Keyword arguments shared by the training and evaluation pipelines.
    shared_input_kwargs = dict(
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype)
    train_ds = input_fn(
        is_training=True,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        **shared_input_kwargs)
    eval_ds = input_fn(is_training=False, **shared_input_kwargs)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster())

    with keras_common.get_strategy_scope(strategy):
        optimizer = keras_common.get_optimizer()
        if use_fp16:
            # TODO(reedwm): Remove manually wrapping optimizer once mixed
            # precision can be enabled with a single line of code.
            optimizer = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    optimizer,
                    loss_scale=flags_core.get_loss_scale(flags_obj)))
        model = resnet_model.resnet50(
            num_classes=imagenet_main.NUM_CLASSES, dtype=dtype)
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=['sparse_categorical_accuracy'])

    time_callback, tensorboard_callback, lr_callback = (
        keras_common.get_callbacks(
            learning_rate_schedule, imagenet_main.NUM_IMAGES['train']))

    steps_per_epoch = (
        imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size)
    epochs = flags_obj.train_epochs
    if flags_obj.train_steps:
        # An explicit step budget caps the run at one (possibly short) epoch.
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        epochs = 1

    eval_steps = (
        imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size)
    validation_ds = eval_ds
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        eval_steps = None
        validation_ds = None

    history = model.fit(
        train_ds,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=[time_callback, lr_callback, tensorboard_callback],
        validation_steps=eval_steps,
        validation_data=validation_ds,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    if flags_obj.skip_eval:
        eval_output = None
    else:
        eval_output = model.evaluate(eval_ds, steps=eval_steps, verbose=2)

    return keras_common.build_stats(history, eval_output, time_callback)
def run(flags_obj):
    """Train and evaluate an ImageNet model with the native Keras APIs.

    Builds either ResNet-50 or the trivial model, depending on
    `flags_obj.use_trivial_model`.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        The stats assembled by `keras_common.build_stats` from the fit
        history, the final evaluation output and the callbacks.
    """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if not keras_common.is_v2_0():
        v1_config = keras_common.get_config_proto_v1()
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=v1_config)
        else:
            tf.keras.backend.set_session(tf.Session(config=v1_config))
    else:
        keras_common.set_config_v2()

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)

    dtype = flags_core.get_tf_dtype(flags_obj)
    use_fp16 = dtype == 'float16'
    if use_fp16:
        tf.keras.mixed_precision.experimental.set_policy(
            tf.keras.mixed_precision.experimental.Policy(
                'infer_float32_vars'))

    # Without an explicit flag, choose the layout from the CUDA build status.
    image_format = flags_obj.data_format
    if image_format is None:
        if tf.test.is_built_with_cuda():
            image_format = 'channels_first'
        else:
            image_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster())
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if not flags_obj.use_synthetic_data:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn
    else:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    shared_input_kwargs = dict(
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype,
        drop_remainder=flags_obj.enable_xla)

    train_ds = input_fn(
        is_training=True,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        **shared_input_kwargs)

    if flags_obj.skip_eval:
        eval_ds = None
    else:
        eval_ds = input_fn(is_training=False, **shared_input_kwargs)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        if use_fp16:
            # TODO(reedwm): Remove manually wrapping optimizer once mixed
            # precision can be enabled with a single line of code.
            optimizer = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    optimizer,
                    loss_scale=flags_core.get_loss_scale(flags_obj)))

        # The input layer only gets a fixed batch size for XLA in graph mode
        # on a single replica; every other configuration leaves it unset.
        per_replica_batch_size = None
        if flags_obj.enable_xla and not flags_obj.enable_eager:
            # TODO(b/129861005): Fix OOM issue in eager mode when setting
            # `batch_size` in keras.Input layer.
            multi_replica = strategy and strategy.num_replicas_in_sync > 1
            # TODO(b/129791381): Specify `per_replica_batch_size` value in
            # DistributionStrategy multi-replica case.
            if not multi_replica:
                per_replica_batch_size = flags_obj.batch_size

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_main.NUM_CLASSES,
                dtype=dtype,
                batch_size=per_replica_batch_size)

        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=['sparse_categorical_accuracy'])

    callbacks = keras_common.get_callbacks(
        learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

    steps_per_epoch = (
        imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size)
    epochs = flags_obj.train_epochs
    if flags_obj.train_steps:
        # An explicit step budget caps the run at one (possibly short) epoch.
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        epochs = 1

    eval_steps = (
        imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size)
    validation_ds = eval_ds
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        eval_steps = None
        validation_ds = None

    history = model.fit(
        train_ds,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_steps=eval_steps,
        validation_data=validation_ds,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    if flags_obj.skip_eval:
        eval_output = None
    else:
        eval_output = model.evaluate(eval_ds, steps=eval_steps, verbose=2)

    return keras_common.build_stats(history, eval_output, callbacks)
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If the resolved dtype compares equal to 'fp16'.

    Returns:
        The stats assembled by `common.build_stats` from the fit history,
        the final evaluation output and the callbacks.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=(
                flags_obj.datasets_num_private_threads))
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): confirm get_tf_dtype can return a value equal to 'fp16';
    # if it returns a TensorFlow dtype object this guard may never fire.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in
        # normalization layer, which triggers tf.where and leads to extra
        # memory copy of input sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=cifar_preprocessing.parse_record)

    with strategy_scope:
        optimizer = common.get_optimizer()
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)

        compile_kwargs = dict(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly)
        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no
        # longer a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            compile_kwargs['experimental_run_tf_function'] = (
                flags_obj.force_v2_in_keras_compile)
        model.compile(**compile_kwargs)

    steps_per_epoch = (
        cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch, learning_rate_schedule)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training
        # loop when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    # try/finally guarantees the manually entered device scope is closed even
    # when fit/evaluate raises. __exit__ receives the three arguments the
    # context-manager protocol requires.
    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
    """Runs the object detection model on distribution strategy defined by the user.

    Args:
        params: Model/run configuration object; its `architecture`,
            `strategy_config`, `strategy_type`, `train`, `eval` and
            `model_dir` attributes are read here.
        mode: One of 'train', 'eval' or 'eval_once'.
        checkpoint_path: Checkpoint file or directory; required for
            'eval_once'.
        train_input_fn: Input function used in 'train' mode.
        eval_input_fn: Input function used in the eval modes.
        callbacks: Passed to the executor's `train` as `custom_callbacks`.
        prebuilt_strategy: Strategy to use as-is; when None one is built
            from `params.strategy_config`.

    Returns:
        The result of the executor's `train` call in 'train' mode, otherwise
        the evaluation results.

    Raises:
        ValueError: If `mode` is not recognised, or 'eval_once' is requested
            with an empty `checkpoint_path`.
    """
    if params.architecture.use_bfloat16:
        bf16_policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(
            bf16_policy)

    model_builder = model_factory.model_generator(params)

    if prebuilt_strategy is None:
        strategy_config = params.strategy_config
        distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                             strategy_config.task_index)
        strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=params.strategy_type,
            num_gpus=strategy_config.num_gpus,
            all_reduce_alg=strategy_config.all_reduce_alg,
            num_packs=strategy_config.num_packs,
            tpu_address=strategy_config.tpu)
    else:
        strategy = prebuilt_strategy

    # Ceiling division of the replica count by 8 gives the worker estimate.
    num_workers = int(strategy.num_replicas_in_sync + 7) // 8
    is_multi_host = (int(num_workers) >= 2)

    def _make_executor(model_fn):
        # Both train and eval use an identically configured executor.
        return DetectionDistributedExecutor(
            strategy=strategy,
            params=params,
            model_fn=model_fn,
            loss_fn=model_builder.build_loss_fn,
            is_multi_host=is_multi_host,
            predict_post_process_fn=model_builder.post_processing,
            trainable_variables_filter=(
                model_builder.make_filter_trainable_variables_fn()))

    if mode == 'train':

        def _model_fn(params):
            return model_builder.build_model(params, mode=ModeKeys.TRAIN)

        logging.info(
            'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
            strategy.num_replicas_in_sync, num_workers, is_multi_host)

        dist_executor = _make_executor(_model_fn)

        if is_multi_host:
            # Multi-host runs bind a per-replica batch size to the input fn.
            train_input_fn = functools.partial(
                train_input_fn,
                batch_size=(params.train.batch_size //
                            strategy.num_replicas_in_sync))

        return dist_executor.train(
            train_input_fn=train_input_fn,
            model_dir=params.model_dir,
            iterations_per_loop=params.train.iterations_per_loop,
            total_steps=params.train.total_steps,
            init_checkpoint=model_builder.make_restore_checkpoint_fn(),
            custom_callbacks=callbacks,
            save_config=True)

    if not (mode == 'eval' or mode == 'eval_once'):
        raise ValueError('Mode not found: %s.' % mode)

    def _model_fn(params):
        return model_builder.build_model(params,
                                         mode=ModeKeys.PREDICT_WITH_GT)

    logging.info(
        'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    if is_multi_host:
        # Multi-host runs bind a per-replica batch size to the input fn.
        eval_input_fn = functools.partial(
            eval_input_fn,
            batch_size=(params.eval.batch_size //
                        strategy.num_replicas_in_sync))

    dist_executor = _make_executor(_model_fn)

    if mode == 'eval':
        results = dist_executor.evaluate_from_model_dir(
            model_dir=params.model_dir,
            eval_input_fn=eval_input_fn,
            eval_metric_fn=model_builder.eval_metrics,
            eval_timeout=params.eval.eval_timeout,
            min_eval_interval=params.eval.min_eval_interval,
            total_steps=params.train.total_steps)
    else:
        # Run evaluation once for a single checkpoint.
        if not checkpoint_path:
            raise ValueError('checkpoint_path cannot be empty.')
        if tf.io.gfile.isdir(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
        results, _ = dist_executor.evaluate_checkpoint(
            checkpoint_path=checkpoint_path,
            eval_input_fn=eval_input_fn,
            eval_metric_fn=model_builder.eval_metrics,
            summary_writer=summary_writer)

    for metric_name, metric_value in results.items():
        logging.info('Final eval metric %s: %f', metric_name, metric_value)
    return results
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If the resolved dtype compares equal to 'fp16'.

    Returns:
        The stats assembled by `keras_common.build_stats` from the fit
        history, the final evaluation output and the callbacks.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): confirm get_tf_dtype can return a value equal to 'fp16';
    # if it returns a TensorFlow dtype object this guard may never fire.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=cifar_main.HEIGHT,
            width=cifar_main.WIDTH,
            num_channels=cifar_main.NUM_CHANNELS,
            num_classes=cifar_main.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj))
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_main.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_main.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype)

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=cifar_main.parse_record)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly)

    callbacks = keras_common.get_callbacks(
        learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

    train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    if flags_obj.train_steps:
        # An explicit step budget caps the run at one (possibly short) epoch.
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training
        # loop when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    # try/finally guarantees the manually entered device scope is closed even
    # when fit/evaluate raises. __exit__ receives the three arguments the
    # context-manager protocol requires.
    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Builds either ResNet-50 or the trivial model, depending on
    `flags_obj.use_trivial_model`.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        The stats assembled by `common.build_stats` from the fit history,
        the final evaluation output and the callbacks.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        common.data_delay_prefetch()
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(
        flags_obj.worker_hosts, flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # Constant learning rate unless the tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in LR_SCHEDULE),
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed
            # precision can be enabled with a single line of code.
            optimizer = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    optimizer,
                    loss_scale=flags_core.get_loss_scale(
                        flags_obj, default_for_fp16=128)))

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES, dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

        compile_kwargs = dict(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly)
        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no
        # longer a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            compile_kwargs['experimental_run_tf_function'] = (
                flags_obj.force_v2_in_keras_compile)
        model.compile(**compile_kwargs)

    callbacks = common.get_callbacks(
        learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

    train_steps = (imagenet_preprocessing.NUM_IMAGES['train'] //
                   flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    if flags_obj.train_steps:
        # An explicit step budget caps the run at one (possibly short) epoch.
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training
        # loop when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    # try/finally guarantees the manually entered device scope is closed even
    # when fit/evaluate raises. __exit__ receives the three arguments the
    # context-manager protocol requires.
    try:
        # NOTE(review): the `// 15` divisor shortens every epoch to a
        # fifteenth of the computed step count; confirm this is intentional.
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps // 15,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=1)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=1)
    finally:
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run_train(flags_obj):
  """Trains a ResNet-50 Keras model, exports it and optionally evaluates it.

  Behaviour is driven partly by `flags_obj` and partly by the module-level
  `GB_OPTIONS` object (batch size, number of epochs, checkpoint folder and an
  optional directory holding pretrained weights).

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The result of `common.build_stats` for the `model.fit` history, the
    evaluation output (`None` when `flags_obj.skip_eval` is set) and the
    training callbacks.

  Raises:
    ValueError: If `GB_OPTIONS.pretrained_filepath` is set but no checkpoint
      can be found in it.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  distribution_utils.undo_set_up_synthetic_data()

  # The fourth element returned by setup_datasets is not needed here.
  train_input_dataset, eval_input_dataset, tr_dataset, _ = setup_datasets(
      flags_obj)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=GB_OPTIONS.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // GB_OPTIONS.batch_size)

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = build_model(imagenet_preprocessing.NUM_CLASSES, mode='resnet50')

    if GB_OPTIONS.pretrained_filepath is not None:
      latest = tf.train.latest_checkpoint(GB_OPTIONS.pretrained_filepath)
      print(latest)
      # latest_checkpoint returns None when the directory holds no checkpoint;
      # fail with an explicit message instead of deep inside load_weights.
      if latest is None:
        raise ValueError(
            'No checkpoint found in pretrained_filepath: %s'
            % GB_OPTIONS.pretrained_filepath)
      model.load_weights(latest)

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=['sparse_categorical_accuracy'])

  train_epochs = GB_OPTIONS.num_epochs

  # The poison/cover counts are only used to tag checkpoint file names; they
  # default to 0 when the training dataset object does not expose n_poison.
  if hasattr(tr_dataset, "n_poison"):
    n_poison = tr_dataset.n_poison
    n_cover = tr_dataset.n_cover
  else:
    n_poison = 0
    n_cover = 0

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=False,
      model_dir=GB_OPTIONS.checkpoint_folder
  )
  # '{epoch:04d}' is left for Keras to fill in; only the counts are
  # substituted here.
  ckpt_full_path = os.path.join(
      GB_OPTIONS.checkpoint_folder,
      'model.ckpt-{epoch:04d}-p%d-c%d'%(n_poison,n_cover))
  callbacks.append(
      tf.keras.callbacks.ModelCheckpoint(
          ckpt_full_path, save_weights_only=True, save_best_only=True))

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // GB_OPTIONS.batch_size)

  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    eval_input_dataset = None

  history = model.fit(
      train_input_dataset,
      epochs=train_epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=callbacks,
      validation_steps=num_eval_steps,
      validation_data=eval_input_dataset,
      validation_freq=flags_obj.epochs_between_evals
  )

  export_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'saved_model')
  model.save(export_path, include_optimizer=False)

  # With skip_eval the evaluation dataset has been cleared above, so the final
  # evaluation must be skipped as well rather than called with None.
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(
        eval_input_dataset,
        steps=num_eval_steps,
        verbose=2
    )

  stats = common.build_stats(history, eval_output, callbacks)

  # Keep a copy of the configuration file next to the checkpoints.
  # NOTE(review): this goes through the shell, so a checkpoint folder that
  # contains spaces or shell metacharacters will not be copied correctly, and
  # a failing copy is silently ignored.
  cmmd = 'cp config.py '+GB_OPTIONS.checkpoint_folder
  os.system(cmmd)
  return stats
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  """Runs the shared Estimator train / evaluate / export loop for ResNet.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: The function that instantiates the Model and builds the
      ops for train/eval. It is handed to the estimator as `model_fn`.
    input_function: The function that processes the dataset and returns a
      dataset the estimator can train on. It is wrapped with the relevant
      flag values before being passed to the estimator.
    dataset_name: The name of the dataset for training and evaluation. Only
      used for benchmark logging.
    shape: List of ints giving the shape of the images used for training.
      Only used when flags_obj.export_dir is set.

  Returns:
    An empty dict when training is driven by
    `tf.estimator.train_and_evaluate`. Otherwise a dict with the keys
    `eval_results` (the output of the last `classifier.evaluate` call) and
    `train_hooks` (the list of hook instances used during training).
  """
  model_helpers.apply_clean(flags.FLAGS)

  # The flag override logic only runs when it is explicitly requested.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Cluster spec for the distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # allow_soft_placement=True is required for multi-GPU and is harmless in
  # the other modes.
  sess_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  dist_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Time-based checkpointing is set to once every 24 hours, so in practice
  # checkpoints are determined only by `epochs_between_evals`.
  estimator_config = tf.estimator.RunConfig(
      train_distribute=dist_strategy,
      session_config=sess_config,
      save_checkpoints_secs=60 * 60 * 24,
      save_checkpoints_steps=None)

  # Warm-start everything except the dense layer from a pretrained ResNet
  # when a checkpoint path is supplied.
  warm_start_settings = None
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=estimator_config,
      warm_start_from=warm_start_settings,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune,
          'num_workers': num_workers,
      })

  benchmark_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, benchmark_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs, input_context=None):
    """Builds the training input from the wrapped `input_function`."""
    per_replica_batch = distribution_utils.per_replica_batch_size(
        flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=per_replica_batch,
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches,
        input_context=input_context)

  def input_fn_eval():
    """Builds the single-epoch evaluation input."""
    per_replica_batch = distribution_utils.per_replica_batch_size(
        flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=per_replica_batch,
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    train_epochs = 0
  else:
    train_epochs = flags_obj.train_epochs

  use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
  if use_train_and_evaluate:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=flags_obj.max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    tf.compat.v1.logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evaluate doesn't return anything in the
    # multi-worker case, so there are no stats to report.
    return {}

  if train_epochs == 0:
    # With --eval_only, perform a single loop with zero train epochs.
    n_loops = 1
    schedule = [0]
  else:
    # Number of train/eval cycles. Every cycle but the last trains for
    # `epochs_between_evals` epochs; the last one trains for whatever is
    # still needed to reach `train_epochs`. For instance, with
    #   train_epochs = 25 and epochs_between_evals = 10
    # the schedule is [10, 10, 5]: train 10 epochs and evaluate, train
    # another 10 and evaluate, then train the final 5 and evaluate.
    n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals] * int(n_loops)
    # Correct the last entry for over counting.
    schedule[-1] = train_epochs - sum(schedule[:-1])

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                              int(n_loops))

    if num_train_epochs:
      # classifier.train is called immediately in each iteration, so the
      # value of num_train_epochs captured by the lambda cannot change
      # before it is used; the pylint warning is safe to ignore.
      # pylint: disable=cell-var-from-loop
      classifier.train(
          input_fn=lambda input_context=None: input_fn_train(
              num_train_epochs, input_context=input_context),
          hooks=train_hooks,
          max_steps=flags_obj.max_train_steps)

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling, and so is frequently combined with synthetic data, which
    # iterates forever. Passing steps=flags_obj.max_train_steps lets the
    # eval (generally unimportant in those circumstances) terminate. Note
    # that eval runs for max_train_steps in every cycle, regardless of the
    # global_step count.
    tf.compat.v1.logging.info('Starting to evaluate.')
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Export a saved model for the trained classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

  return {
      'eval_results': eval_results,
      'train_hooks': train_hooks,
  }
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `--fp16_implementation=graph_rewrite` is requested without
      `--use_tf_function`.
    NotImplementedError: If `flags_obj.learning_rate_decay_type` is neither
      'piecewise' nor 'cosine'.

  Returns:
    Dictionary of training and eval stats.
  """
  print('@@@@enable_eager = {}'.format(flags_obj.enable_eager))
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  train_ds, test_ds = get_input_dataset(flags_obj, strategy)
  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
  logging.info("Training %d epochs, each epoch has %d steps, "
               "total steps: %d; Eval %d steps",
               train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
               eval_steps)

  time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                          flags_obj.log_steps)

  with distribution_utils.get_strategy_scope(strategy):
    resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
    # When a single fused L2 op is requested the per-layer regularizers are
    # disabled and the L2 term is added by hand inside step_fn.
    use_l2_regularizer = not flags_obj.single_l2_loss_op

    if flags_obj.use_resnet_d:
      resnetd = network_tweaks.ResnetD(
          image_data_format=tf.keras.backend.image_data_format(),
          use_l2_regularizer=use_l2_regularizer)
    else:
      resnetd = None

    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        batch_size=flags_obj.batch_size,
        zero_gamma=flags_obj.zero_gamma,
        last_pool_channel_type=flags_obj.last_pool_channel_type,
        use_l2_regularizer=use_l2_regularizer,
        resnetd=resnetd)

    if flags_obj.learning_rate_decay_type == 'piecewise':
      lr_schedule = common.PiecewiseConstantDecayWithWarmup(
          batch_size=flags_obj.batch_size,
          epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
          warmup_epochs=common.LR_SCHEDULE[0][1],
          boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
          multipliers=list(p[0] for p in common.LR_SCHEDULE),
          compute_lr_on_cpu=True)
    elif flags_obj.learning_rate_decay_type == 'cosine':
      lr_schedule = common.CosineDecayWithWarmup(
          base_lr=flags_obj.base_learning_rate,
          batch_size=flags_obj.batch_size,
          epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
          warmup_epochs=common.LR_SCHEDULE[0][1],
          train_epochs=flags_obj.train_epochs,
          compute_lr_on_cpu=True)
    else:
      raise NotImplementedError(
          'Unsupported learning_rate_decay_type: %r'
          % (flags_obj.learning_rate_decay_type,))

    optimizer = common.get_optimizer(lr_schedule)

    if dtype == tf.float16:
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale)
    elif flags_obj.fp16_implementation == 'graph_rewrite':
      # `dtype` is still float32 in this case. We built the graph in float32
      # and let the graph rewrite change parts of it float16.
      if not flags_obj.use_tf_function:
        raise ValueError('--fp16_implementation=graph_rewrite requires '
                         '--use_tf_function to be true')
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer, loss_scale)

    # Resume from the most recent checkpoint in model_dir, if there is one.
    current_step = 0
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
    if latest_checkpoint:
      checkpoint.restore(latest_checkpoint)
      logging.info("Load checkpoint %s", latest_checkpoint)
      current_step = optimizer.iterations.numpy()

    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)

    categorical_cross_entopy_and_acc = losses.CategoricalCrossEntropyAndAcc(
        batch_size=flags_obj.batch_size,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        label_smoothing=flags_obj.label_smoothing)
    trainable_variables = model.trainable_variables

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss = categorical_cross_entopy_and_acc.loss_and_update_acc(
            labels, logits, training=True)
        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

        if flags_obj.single_l2_loss_op:
          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
              tf.nn.l2_loss(v)
              for v in trainable_variables
              if 'bn' not in v.name
          ])
          loss += (l2_loss / num_replicas)
        else:
          loss += (tf.reduce_sum(model.losses) / num_replicas)

        # Scale the loss for the backward pass only. The unscaled `loss` is
        # kept so that the train_loss metric is not inflated by the loss
        # scale.
        if flags_obj.dtype == "fp16":
          scaled_loss = optimizer.get_scaled_loss(loss)
        else:
          scaled_loss = loss

      grads = tape.gradient(scaled_loss, trainable_variables)

      # Unscale the grads
      if flags_obj.dtype == "fp16":
        grads = optimizer.get_unscaled_gradients(grads)

      optimizer.apply_gradients(zip(grads, trainable_variables))
      train_loss.update_state(loss)

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop."""
      for _ in tf.range(steps):
        # Mirror train_single_step: fall back to a direct call when no
        # distribution strategy is in use.
        if strategy:
          strategy.experimental_run_v2(step_fn, args=(next(iterator),))
        else:
          step_fn(next(iterator))

    def train_single_step(iterator):
      """Runs one training step, with or without a distribution strategy."""
      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        return step_fn(next(iterator))

    def test_step(iterator):
      """Evaluation StepFn."""
      def step_fn(inputs):
        images, labels = inputs
        logits = model(images, training=False)
        loss = categorical_cross_entopy_and_acc.loss_and_update_acc(
            labels, logits, training=False)
        test_loss.update_state(loss)

      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        step_fn(next(iterator))

    if flags_obj.use_tf_function:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    if flags_obj.enable_tensorboard:
      summary_writer = tf.summary.create_file_writer(flags_obj.model_dir)
    else:
      summary_writer = None

    train_iter = iter(train_ds)
    time_callback.on_train_begin()
    # Start from the epoch implied by the restored optimizer step count.
    for epoch in range(current_step // per_epoch_steps, train_epochs):
      train_loss.reset_states()
      categorical_cross_entopy_and_acc.training_accuracy.reset_states()

      steps_in_current_epoch = 0
      while steps_in_current_epoch < per_epoch_steps:
        time_callback.on_batch_begin(
            steps_in_current_epoch+epoch*per_epoch_steps)
        steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                              steps_per_loop)
        if steps == 1:
          train_single_step(train_iter)
        else:
          # Converts steps to a Tensor to avoid tf.function retracing.
          train_steps(train_iter, tf.convert_to_tensor(steps, dtype=tf.int32))
        time_callback.on_batch_end(
            steps_in_current_epoch+epoch*per_epoch_steps)
        steps_in_current_epoch += steps

      # The cross_entropy slot is currently always logged as 0.
      logging.info(
          'Training loss: %s, accuracy: %s, cross_entropy: %s at epoch %d',
          train_loss.result().numpy(),
          categorical_cross_entopy_and_acc.training_accuracy.result().numpy(),
          0.,
          epoch + 1)

      if (not flags_obj.skip_eval and
          (epoch + 1) % flags_obj.epochs_between_evals == 0):
        test_loss.reset_states()
        categorical_cross_entopy_and_acc.test_accuracy.reset_states()

        test_iter = iter(test_ds)
        for _ in range(eval_steps):
          test_step(test_iter)

        logging.info(
            'Test loss: %s, accuracy: %s%% at epoch: %d',
            test_loss.result().numpy(),
            categorical_cross_entopy_and_acc.test_accuracy.result().numpy(),
            epoch + 1)

      if flags_obj.enable_checkpoint_and_export:
        checkpoint_name = checkpoint.save(
            os.path.join(flags_obj.model_dir,
                         'model.ckpt-{}'.format(epoch + 1)))
        logging.info('Saved checkpoint to %s', checkpoint_name)

      if summary_writer:
        current_steps = steps_in_current_epoch + (epoch * per_epoch_steps)
        with summary_writer.as_default():
          tf.summary.scalar('train_loss', train_loss.result(), current_steps)
          tf.summary.scalar(
              'train_accuracy',
              categorical_cross_entopy_and_acc.training_accuracy.result(),
              current_steps)
          lr_for_monitor = lr_schedule(current_steps)
          if callable(lr_for_monitor):
            lr_for_monitor = lr_for_monitor()
          tf.summary.scalar('learning_rate', lr_for_monitor, current_steps)
          tf.summary.scalar('eval_loss', test_loss.result(), current_steps)
          tf.summary.scalar(
              'eval_accuracy',
              categorical_cross_entopy_and_acc.test_accuracy.result(),
              current_steps)

    time_callback.on_train_end()
    if summary_writer:
      summary_writer.close()

    # Both results are only reported when evaluation was enabled.
    eval_result = None
    train_result = None
    if not flags_obj.skip_eval:
      eval_result = [
          test_loss.result().numpy(),
          categorical_cross_entopy_and_acc.test_accuracy.result().numpy()
      ]
      train_result = [
          train_loss.result().numpy(),
          categorical_cross_entopy_and_acc.training_accuracy.result().numpy()
      ]

    stats = build_stats(train_result, eval_result, time_callback)
    return stats