def run(flags_obj): """Run ResNet ImageNet training and eval loop using native Keras APIs. Args: flags_obj: An object containing parsed flag values. Raises: ValueError: If fp16 is passed as it is not currently supported. NotImplementedError: If some features are not currently supported. Returns: Dictionary of training and eval stats. """ keras_utils.set_session_config( enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla) # Execute flag override logic for better model performance if flags_obj.tf_gpu_thread_mode: keras_utils.set_gpu_thread_mode_and_count( per_gpu_thread_count=flags_obj.per_gpu_thread_count, gpu_thread_mode=flags_obj.tf_gpu_thread_mode, num_gpus=flags_obj.num_gpus, datasets_num_private_threads=flags_obj.datasets_num_private_threads) common.set_cudnn_batchnorm_mode() dtype = flags_core.get_tf_dtype(flags_obj) performance.set_mixed_precision_policy( flags_core.get_tf_dtype(flags_obj), flags_core.get_loss_scale(flags_obj, default_for_fp16=128)) data_format = flags_obj.data_format if data_format is None: data_format = ('channels_first' if tf.config.list_physical_devices('GPU') else 'channels_last') tf.keras.backend.set_image_data_format(data_format) # Configures cluster spec for distribution strategy. _ = distribution_utils.configure_cluster(flags_obj.worker_hosts, flags_obj.task_index) strategy = distribution_utils.get_distribution_strategy( distribution_strategy=flags_obj.distribution_strategy, num_gpus=flags_obj.num_gpus, all_reduce_alg=flags_obj.all_reduce_alg, num_packs=flags_obj.num_packs, tpu_address=flags_obj.tpu) if strategy: # flags_obj.enable_get_next_as_optional controls whether enabling # get_next_as_optional behavior in DistributedIterator. If true, last # partial batch can be supported. strategy.extended.experimental_enable_get_next_as_optional = ( flags_obj.enable_get_next_as_optional ) strategy_scope = distribution_utils.get_strategy_scope(strategy) # pylint: disable=protected-access if flags_obj.use_synthetic_data: input_fn = common.get_synth_input_fn( height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE, width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE, num_channels=imagenet_preprocessing.NUM_CHANNELS, num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype, drop_remainder=True) else: input_fn = imagenet_preprocessing.input_fn # When `enable_xla` is True, we always drop the remainder of the batches # in the dataset, as XLA-GPU doesn't support dynamic shapes. drop_remainder = flags_obj.enable_xla # Current resnet_model.resnet50 input format is always channel-last. # We use keras_application mobilenet model which input format is depends on # the keras beckend image data format. # This use_keras_image_data_format flags indicates whether image preprocessor # output format should be same as the keras backend image data format or just # channel-last format. use_keras_image_data_format = (flags_obj.model == 'mobilenet') train_input_dataset = input_fn( is_training=True, data_dir=flags_obj.data_dir, batch_size=flags_obj.batch_size, parse_record_fn=imagenet_preprocessing.get_parse_record_fn( use_keras_image_data_format=use_keras_image_data_format), datasets_num_private_threads=flags_obj.datasets_num_private_threads, dtype=dtype, drop_remainder=drop_remainder, tf_data_experimental_slack=flags_obj.tf_data_experimental_slack, training_dataset_cache=flags_obj.training_dataset_cache, ) eval_input_dataset = None if not flags_obj.skip_eval: eval_input_dataset = input_fn( is_training=False, data_dir=flags_obj.data_dir, batch_size=flags_obj.batch_size, parse_record_fn=imagenet_preprocessing.get_parse_record_fn( use_keras_image_data_format=use_keras_image_data_format), dtype=dtype, drop_remainder=drop_remainder) lr_schedule = common.PiecewiseConstantDecayWithWarmup( batch_size=flags_obj.batch_size, epoch_size=imagenet_preprocessing.NUM_IMAGES['train'], warmup_epochs=common.LR_SCHEDULE[0][1], boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]), multipliers=list(p[0] for p in common.LR_SCHEDULE), compute_lr_on_cpu=True) steps_per_epoch = ( imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size) with strategy_scope: if flags_obj.optimizer == 'resnet50_default': optimizer = common.get_optimizer(lr_schedule) elif flags_obj.optimizer == 'mobilenet_default': initial_learning_rate = \ flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size optimizer = tf.keras.optimizers.SGD( learning_rate=tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate, decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay, decay_rate=flags_obj.lr_decay_factor, staircase=True), momentum=0.9) if flags_obj.fp16_implementation == 'graph_rewrite': # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32' # which will ensure tf.compat.v2.keras.mixed_precision and # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double # up. optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite( optimizer) # TODO(hongkuny): Remove trivial model usage and move it to benchmark. if flags_obj.use_trivial_model: model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES) elif flags_obj.model == 'resnet50_v1.5': model = resnet_model.resnet50( num_classes=imagenet_preprocessing.NUM_CLASSES) elif flags_obj.model == 'mobilenet': # TODO(kimjaehong): Remove layers attribute when minimum TF version # support 2.0 layers by default. model = tf.keras.applications.mobilenet.MobileNet( weights=None, classes=imagenet_preprocessing.NUM_CLASSES, layers=tf.keras.layers) if flags_obj.pretrained_filepath: model.load_weights(flags_obj.pretrained_filepath) if flags_obj.pruning_method == 'polynomial_decay': if dtype != tf.float32: raise NotImplementedError( 'Pruning is currently only supported on dtype=tf.float32.') pruning_params = { 'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay( initial_sparsity=flags_obj.pruning_initial_sparsity, final_sparsity=flags_obj.pruning_final_sparsity, begin_step=flags_obj.pruning_begin_step, end_step=flags_obj.pruning_end_step, frequency=flags_obj.pruning_frequency), } model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params) elif flags_obj.pruning_method: raise NotImplementedError( 'Only polynomial_decay is currently supported.') model.compile( loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=(['sparse_categorical_accuracy'] if flags_obj.report_accuracy_metrics else None), run_eagerly=flags_obj.run_eagerly) train_epochs = flags_obj.train_epochs callbacks = common.get_callbacks( pruning_method=flags_obj.pruning_method, enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export, model_dir=flags_obj.model_dir) # if mutliple epochs, ignore the train_steps flag. if train_epochs <= 1 and flags_obj.train_steps: steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch) train_epochs = 1 num_eval_steps = ( imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size) validation_data = eval_input_dataset if flags_obj.skip_eval: # Only build the training graph. This reduces memory usage introduced by # control flow ops in layers that have different implementations for # training and inference (e.g., batch norm). if flags_obj.set_learning_phase_to_train: # TODO(haoyuzhang): Understand slowdown of setting learning phase when # not using distribution strategy. tf.keras.backend.set_learning_phase(1) num_eval_steps = None validation_data = None if not strategy and flags_obj.explicit_gpu_placement: # TODO(b/135607227): Add device scope automatically in Keras training loop # when not using distribition strategy. no_dist_strat_device = tf.device('/device:GPU:0') no_dist_strat_device.__enter__() history = model.fit(train_input_dataset, epochs=train_epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks, validation_steps=num_eval_steps, validation_data=validation_data, validation_freq=flags_obj.epochs_between_evals, verbose=2) eval_output = None if not flags_obj.skip_eval: eval_output = model.evaluate(eval_input_dataset, steps=num_eval_steps, verbose=2) if flags_obj.pruning_method: model = tfmot.sparsity.keras.strip_pruning(model) if flags_obj.enable_checkpoint_and_export: if dtype == tf.bfloat16: logging.warning('Keras model.save does not support bfloat16 dtype.') else: # Keras model.save assumes a float32 input designature. export_path = os.path.join(flags_obj.model_dir, 'saved_model') model.save(export_path, include_optimizer=False) if not strategy and flags_obj.explicit_gpu_placement: no_dist_strat_device.__exit__() stats = common.build_stats(history, eval_output, callbacks) return stats
def run(flags_obj): """Run ResNet Cifar-10 training and eval loop using native Keras APIs. Args: flags_obj: An object containing parsed flag values. Raises: ValueError: If fp16 is passed as it is not currently supported. Returns: Dictionary of training and eval stats. """ keras_utils.set_session_config(enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla) # Execute flag override logic for better model performance if flags_obj.tf_gpu_thread_mode: keras_utils.set_gpu_thread_mode_and_count( per_gpu_thread_count=flags_obj.per_gpu_thread_count, gpu_thread_mode=flags_obj.tf_gpu_thread_mode, num_gpus=flags_obj.num_gpus, datasets_num_private_threads=flags_obj.datasets_num_private_threads ) common.set_cudnn_batchnorm_mode() dtype = flags_core.get_tf_dtype(flags_obj) if dtype == 'fp16': raise ValueError( 'dtype fp16 is not supported in Keras. Use the default ' 'value(fp32).') data_format = flags_obj.data_format if data_format is None: data_format = ('channels_first' if tf.config.list_physical_devices('GPU') else 'channels_last') tf.keras.backend.set_image_data_format(data_format) strategy = distribution_utils.get_distribution_strategy( distribution_strategy=flags_obj.distribution_strategy, num_gpus=flags_obj.num_gpus, all_reduce_alg=flags_obj.all_reduce_alg, num_packs=flags_obj.num_packs) if strategy: # flags_obj.enable_get_next_as_optional controls whether enabling # get_next_as_optional behavior in DistributedIterator. If true, last # partial batch can be supported. strategy.extended.experimental_enable_get_next_as_optional = ( flags_obj.enable_get_next_as_optional) strategy_scope = distribution_utils.get_strategy_scope(strategy) if flags_obj.use_synthetic_data: synthetic_util.set_up_synthetic_data() input_fn = common.get_synth_input_fn( height=cifar_preprocessing.HEIGHT, width=cifar_preprocessing.WIDTH, num_channels=cifar_preprocessing.NUM_CHANNELS, num_classes=cifar_preprocessing.NUM_CLASSES, dtype=flags_core.get_tf_dtype(flags_obj), drop_remainder=True) else: synthetic_util.undo_set_up_synthetic_data() input_fn = cifar_preprocessing.input_fn train_input_dataset = input_fn( is_training=True, data_dir=flags_obj.data_dir, batch_size=flags_obj.batch_size, parse_record_fn=cifar_preprocessing.parse_record, datasets_num_private_threads=flags_obj.datasets_num_private_threads, dtype=dtype, # Setting drop_remainder to avoid the partial batch logic in normalization # layer, which triggers tf.where and leads to extra memory copy of input # sizes between host and GPU. drop_remainder=(not flags_obj.enable_get_next_as_optional)) eval_input_dataset = None if not flags_obj.skip_eval: eval_input_dataset = input_fn( is_training=False, data_dir=flags_obj.data_dir, batch_size=flags_obj.batch_size, parse_record_fn=cifar_preprocessing.parse_record) steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size) lr_schedule = 0.1 if flags_obj.use_tensor_lr: initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128 lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay( boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE), values=[initial_learning_rate] + list(p[0] * initial_learning_rate for p in LR_SCHEDULE)) with strategy_scope: optimizer = common.get_optimizer(lr_schedule) model = resnet_cifar_model.resnet56( classes=cifar_preprocessing.NUM_CLASSES) model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=(['sparse_categorical_accuracy'] if flags_obj.report_accuracy_metrics else None), run_eagerly=flags_obj.run_eagerly) train_epochs = flags_obj.train_epochs callbacks = common.get_callbacks(steps_per_epoch) if not flags_obj.use_tensor_lr: lr_callback = LearningRateBatchScheduler( schedule=learning_rate_schedule, batch_size=flags_obj.batch_size, steps_per_epoch=steps_per_epoch) callbacks.append(lr_callback) # if mutliple epochs, ignore the train_steps flag. if train_epochs <= 1 and flags_obj.train_steps: steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch) train_epochs = 1 num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size) validation_data = eval_input_dataset if flags_obj.skip_eval: if flags_obj.set_learning_phase_to_train: # TODO(haoyuzhang): Understand slowdown of setting learning phase when # not using distribution strategy. tf.keras.backend.set_learning_phase(1) num_eval_steps = None validation_data = None if not strategy and flags_obj.explicit_gpu_placement: # TODO(b/135607227): Add device scope automatically in Keras training loop # when not using distribition strategy. no_dist_strat_device = tf.device('/device:GPU:0') no_dist_strat_device.__enter__() history = model.fit(train_input_dataset, epochs=train_epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks, validation_steps=num_eval_steps, validation_data=validation_data, validation_freq=flags_obj.epochs_between_evals, verbose=2) eval_output = None if not flags_obj.skip_eval: eval_output = model.evaluate(eval_input_dataset, steps=num_eval_steps, verbose=2) if not strategy and flags_obj.explicit_gpu_placement: no_dist_strat_device.__exit__() stats = common.build_stats(history, eval_output, callbacks) return stats
def run_train(flags_obj): keras_utils.set_session_config( enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla) # Execute flag override logic for better model performance if flags_obj.tf_gpu_thread_mode: keras_utils.set_gpu_thread_mode_and_count( per_gpu_thread_count=flags_obj.per_gpu_thread_count, gpu_thread_mode=flags_obj.tf_gpu_thread_mode, num_gpus=flags_obj.num_gpus, datasets_num_private_threads=flags_obj.datasets_num_private_threads) common.set_cudnn_batchnorm_mode() performance.set_mixed_precision_policy( flags_core.get_tf_dtype(flags_obj), flags_core.get_loss_scale(flags_obj, default_for_fp16=128)) data_format = flags_obj.data_format if data_format is None: data_format = ('channels_first' if tf.test.is_built_with_cuda() else 'channels_last') tf.keras.backend.set_image_data_format(data_format) # Configures cluster spec for distribution strategy. _ = distribution_utils.configure_cluster(flags_obj.worker_hosts, flags_obj.task_index) strategy = distribution_utils.get_distribution_strategy( distribution_strategy=flags_obj.distribution_strategy, num_gpus=flags_obj.num_gpus, all_reduce_alg=flags_obj.all_reduce_alg, num_packs=flags_obj.num_packs, tpu_address=flags_obj.tpu) if strategy: # flags_obj.enable_get_next_as_optional controls whether enabling # get_next_as_optional behavior in DistributedIterator. If true, last # partial batch can be supported. strategy.extended.experimental_enable_get_next_as_optional = ( flags_obj.enable_get_next_as_optional ) strategy_scope = distribution_utils.get_strategy_scope(strategy) distribution_utils.undo_set_up_synthetic_data() train_input_dataset, eval_input_dataset, tr_dataset, te_dataset = setup_datasets(flags_obj) lr_schedule = common.PiecewiseConstantDecayWithWarmup( batch_size=GB_OPTIONS.batch_size, epoch_size=imagenet_preprocessing.NUM_IMAGES['train'], warmup_epochs=common.LR_SCHEDULE[0][1], boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]), multipliers=list(p[0] for p in common.LR_SCHEDULE), compute_lr_on_cpu=True) steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] // GB_OPTIONS.batch_size) with strategy_scope: optimizer = common.get_optimizer(lr_schedule) model = build_model(imagenet_preprocessing.NUM_CLASSES, mode='resnet50') if GB_OPTIONS.pretrained_filepath is not None: latest = tf.train.latest_checkpoint(GB_OPTIONS.pretrained_filepath) print(latest) model.load_weights(latest) #losses = ["sparse_categorical_crossentropy"] #lossWeights = [1.0] model.compile( optimizer=optimizer, loss="sparse_categorical_crossentropy", #loss_weights=lossWeights, metrics=['sparse_categorical_accuracy']) train_epochs = GB_OPTIONS.num_epochs if not hasattr(tr_dataset, "n_poison"): n_poison=0 n_cover=0 else: n_poison = tr_dataset.n_poison n_cover = tr_dataset.n_cover callbacks = common.get_callbacks( steps_per_epoch=steps_per_epoch, pruning_method=flags_obj.pruning_method, enable_checkpoint_and_export=False, model_dir=GB_OPTIONS.checkpoint_folder ) ckpt_full_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'model.ckpt-{epoch:04d}-p%d-c%d'%(n_poison,n_cover)) callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path, save_weights_only=True, save_best_only=True)) num_eval_steps = imagenet_preprocessing.NUM_IMAGES['validation'] // GB_OPTIONS.batch_size if flags_obj.skip_eval: # Only build the training graph. This reduces memory usage introduced by # control flow ops in layers that have different implementations for # training and inference (e.g., batch norm). if flags_obj.set_learning_phase_to_train: # TODO(haoyuzhang): Understand slowdown of setting learning phase when # not using distribution strategy. tf.keras.backend.set_learning_phase(1) num_eval_steps = None eval_input_dataset = None history = model.fit( train_input_dataset, epochs=train_epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks, validation_steps=num_eval_steps, validation_data=eval_input_dataset, validation_freq=flags_obj.epochs_between_evals ) export_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'saved_model') model.save(export_path, include_optimizer=False) eval_output = model.evaluate( eval_input_dataset, steps=num_eval_steps, verbose=2 ) stats = common.build_stats(history, eval_output, callbacks) cmmd = 'cp config.py '+GB_OPTIONS.checkpoint_folder os.system(cmmd) return stats