def input_fn_eval():
  """Build the evaluation input pipeline.

  Makes a single pass over the evaluation split, with the global batch size
  divided across the GPUs reported by the flags.
  """
  gpu_count = flags_core.get_num_gpus(flags_obj)
  per_device_batch = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, gpu_count)
  eval_dtype = flags_core.get_tf_dtype(flags_obj)
  return input_function(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=per_device_batch,
      num_epochs=1,
      dtype=eval_dtype)
def input_fn_train(num_epochs):
  """Build the training input pipeline for `num_epochs` passes over the data.

  The global batch size from the flags is split per device; the private
  thread-pool size and parallel-batch count are forwarded from the flags.
  """
  per_device_batch = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
  pipeline_kwargs = dict(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=per_device_batch,
      num_epochs=num_epochs,
      dtype=flags_core.get_tf_dtype(flags_obj),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      num_parallel_batches=flags_obj.datasets_num_parallel_batches)
  return input_function(**pipeline_kwargs)
def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Use an explicit conditional rather than `cond and a or b`: the latter
  # silently falls through to `input_fn` whenever `a` happens to be falsy.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
  else:
    input_function = input_fn

  resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
def run_imagenet(flags_obj):
  """Run ResNet ImageNet training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Explicit conditional instead of the `cond and a or b` idiom, which would
  # fall through to `input_fn` if the synthetic input function were falsy.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
  else:
    input_function = input_fn

  resnet_run_loop.resnet_main(
      flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
      shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])
def test_parse_dtype_info(self):
  # Each case: (--dtype value, expected tf dtype, expected default loss scale).
  cases = (
      ("fp16", tf.float16, 128),
      ("fp32", tf.float32, 1),
  )
  for name, expected_dtype, expected_scale in cases:
    flags_core.parse_flags([__file__, "--dtype", name])
    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), expected_dtype)
    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), expected_scale)

    # An explicit --loss_scale must win over the per-dtype default.
    override_argv = [__file__, "--dtype", name, "--loss_scale", "5"]
    flags_core.parse_flags(override_argv)
    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)

  # An unsupported dtype makes flag parsing exit.
  with self.assertRaises(SystemExit):
    flags_core.parse_flags([__file__, "--dtype", "int8"])
def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of results. Including final accuracy. Returns None, without
    training, when `flags_obj.image_bytes_as_serving_input` is set, because
    that option only applies to ImageNet.
  """
  if flags_obj.image_bytes_as_serving_input:
    # The message is logged and the function returns early; it is the
    # `return` below that stops the CIFAR run.
    tf.logging.fatal('--image_bytes_as_serving_input cannot be set to True '
                     'for CIFAR. This flag is only applicable to ImageNet.')
    return

  # Explicit conditional instead of `cond and a or b`, which silently picks
  # `input_fn` whenever the synthetic input function is falsy.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
  else:
    input_function = input_fn

  result = resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[HEIGHT, WIDTH, NUM_CHANNELS])

  return result
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # `get_tf_dtype` returns a `tf.DType`, so it must be compared against
  # `tf.float16`. Comparing it with the string 'fp16' is never true, so the
  # documented ValueError was previously unreachable.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  # NOTE(review): the eval dataset also uses num_epochs=train_epochs,
  # presumably so it is not exhausted when `fit` validates after every
  # epoch — confirm.
  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # --train_steps caps the run at a single (possibly shortened) epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)

  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
def __init__(self, flags_obj, time_callback, epoch_steps):
  """Builds the ResNet-50 model, optimizer, metrics and checkpoint.

  Args:
    flags_obj: An object containing parsed flag values.
    time_callback: Stored as `self.time_callback`; it is not used inside this
      constructor.
    epoch_steps: Stored as `self.epoch_steps` and passed to
      `utils.EpochHelper` together with the optimizer's step counter
      (presumably the number of training steps per epoch — confirm).

  Raises:
    ValueError: If `flags_obj.batch_size` is not divisible by the number of
      replicas in sync, or if `--fp16_implementation=graph_rewrite` is set
      without `--use_tf_function`.
  """
  standard_runnable.StandardTrainable.__init__(
      self, flags_obj.use_tf_while_loop, flags_obj.use_tf_function)
  standard_runnable.StandardEvaluable.__init__(self, flags_obj.use_tf_function)
  # Captures whichever distribution strategy is current at construction time
  # (presumably the caller builds this object inside `strategy.scope()` —
  # confirm).
  self.strategy = tf.distribute.get_strategy()
  self.flags_obj = flags_obj
  self.dtype = flags_core.get_tf_dtype(flags_obj)
  self.time_callback = time_callback

  # Input pipeline related
  batch_size = flags_obj.batch_size
  if batch_size % self.strategy.num_replicas_in_sync != 0:
    raise ValueError(
        'Batch size must be divisible by number of replicas : {}'.format(
            self.strategy.num_replicas_in_sync))

  # As auto rebatching is not supported in
  # `experimental_distribute_datasets_from_function()` API, which is
  # required when cloning dataset to multiple workers in eager mode,
  # we use per-replica batch size.
  self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)

  if self.flags_obj.use_synthetic_data:
    self.input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=self.dtype,
        drop_remainder=True)
  else:
    self.input_fn = imagenet_preprocessing.input_fn

  # The model and the LR schedule receive the global batch size from the
  # flags, not the per-replica `self.batch_size` computed above.
  self.model = resnet_model.resnet50(
      num_classes=imagenet_preprocessing.NUM_CLASSES,
      batch_size=flags_obj.batch_size,
      use_l2_regularizer=not flags_obj.single_l2_loss_op)

  # Each LR_SCHEDULE entry is read as (multiplier, epoch): index 0 feeds
  # `multipliers`; index 1 feeds the warmup length and the epoch boundaries.
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  self.optimizer = common.get_optimizer(lr_schedule)
  # Make sure iterations variable is created inside scope.
  # This is read from the unwrapped optimizer, before `configure_optimizer`
  # below replaces `self.optimizer`.
  self.global_step = self.optimizer.iterations

  use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
  if use_graph_rewrite and not flags_obj.use_tf_function:
    raise ValueError('--fp16_implementation=graph_rewrite requires '
                     '--use_tf_function to be true')
  # Wraps the optimizer for float16 / graph-rewrite training. 128 is passed
  # as `default_for_fp16` for the loss-scale lookup.
  self.optimizer = performance.configure_optimizer(
      self.optimizer,
      use_float16=self.dtype == tf.float16,
      use_graph_rewrite=use_graph_rewrite,
      loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  # Metrics are accumulated in float32 regardless of the compute dtype.
  self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
  self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'train_accuracy', dtype=tf.float32)
  self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
  self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'test_accuracy', dtype=tf.float32)

  # The checkpoint tracks the (possibly wrapped) optimizer assigned above.
  self.checkpoint = tf.train.Checkpoint(
      model=self.model, optimizer=self.optimizer)

  # Handling epochs.
  self.epoch_steps = epoch_steps
  self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # Local: only used to scope the optional GPU placement.

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    common.set_gpu_thread_mode_and_count(flags_obj)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # `get_tf_dtype` returns a `tf.DType`; comparing it with the string 'fp16'
  # is never true, which made the documented ValueError unreachable.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in normalization
      # layer, which triggers tf.where and leads to extra memory copy of input
      # sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record)

  with strategy_scope:
    optimizer = common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)

    compile_kwargs = dict(
        loss='categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)
    # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
    # a valid arg for this model. Also remove as a valid flag.
    # The argument is only forwarded when the flag was explicitly set.
    if flags_obj.force_v2_in_keras_compile is not None:
      compile_kwargs['experimental_run_tf_function'] = (
          flags_obj.force_v2_in_keras_compile)
    model.compile(**compile_kwargs)

  callbacks = common.get_callbacks(
      learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])

  train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # --train_steps caps the run at a single (possibly shortened) epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # The optional device scope is managed by an ExitStack. It is therefore
  # exited with the proper (exc_type, exc, tb) arguments, and also when `fit`
  # or `evaluate` raises. A bare `__exit__()` call does neither.
  with contextlib.ExitStack() as device_stack:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training loop
      # when not using distribition strategy.
      device_stack.enter_context(tf.device('/device:GPU:0'))

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def __init__(self, flags_obj):
  """Init function of TransformerMain.

  Copies flag values into `self.params`, creates the distribution strategy
  and sets the mixed-precision policy for `params["dtype"]`.

  Args:
    flags_obj: Object containing parsed flag values, i.e., FLAGS.

  Raises:
    ValueError: if not using static batch for input data on TPU.
      NOTE(review): no such check is visible in this method; confirm whether
      it lives elsewhere (e.g. behind `self.use_tpu`) before relying on it.
  """
  self.flags_obj = flags_obj
  # Not built here; left as None by the constructor.
  self.predict_model = None

  # Add flag-defined parameters to params object
  num_gpus = flags_core.get_num_gpus(flags_obj)
  # `params` aliases `self.params`, so every assignment below updates the
  # instance's params in place.
  self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

  params["num_gpus"] = num_gpus
  params["use_ctl"] = flags_obj.use_ctl
  params["data_dir"] = flags_obj.data_dir
  params["model_dir"] = flags_obj.model_dir
  params["static_batch"] = flags_obj.static_batch
  params["max_length"] = flags_obj.max_length
  params["decode_batch_size"] = flags_obj.decode_batch_size
  params["decode_max_length"] = flags_obj.decode_max_length
  params["padded_decode"] = flags_obj.padded_decode
  # Falls back to AUTOTUNE when --num_parallel_calls is falsy (unset or 0).
  params["max_io_parallelism"] = (
      flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

  params["use_synthetic_data"] = flags_obj.use_synthetic_data
  # Falls back to the param set's default when --batch_size is falsy.
  params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
  params["repeat_dataset"] = None
  params["dtype"] = flags_core.get_tf_dtype(flags_obj)
  params["enable_tensorboard"] = flags_obj.enable_tensorboard
  params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
  params["steps_between_evals"] = flags_obj.steps_between_evals
  params["enable_checkpointing"] = flags_obj.enable_checkpointing
  params["save_weights_only"] = flags_obj.save_weights_only

  self.distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu or "")
  # `use_tpu` is defined outside this method; it presumably inspects
  # `self.distribution_strategy`, so it must stay after the assignment above.
  if self.use_tpu:
    params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
  else:
    logging.info("Running transformer with num_gpus = %d", num_gpus)

  if self.distribution_strategy:
    logging.info("For training, using distribution strategy: %s",
                 self.distribution_strategy)
  else:
    logging.info("Not using any distribution strategy.")

  # "dynamic" is passed as the fp16 default for the loss-scale lookup.
  performance.set_mixed_precision_policy(
      params["dtype"],
      flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  # Creates a `RunConfig` that checkpoints every 24 hours which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_secs=60 * 60 * 24)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches)

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
    schedule, n_loops = [0], 1
  else:
    # Compute the number of times to loop while training. All but the last
    # pass will train for `epochs_between_evals` epochs, while the last will
    # train for the number needed to reach `training_epochs`. For instance if
    #   train_epochs = 25 and epochs_between_evals = 10
    # schedule will be set to [10, 10, 5]. That is to say, the loop will:
    #   Train for 10 epochs and then evaluate.
    #   Train for another 10 epochs and then evaluate.
    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
    n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # over counting.

  # Train with the hooks selected by --hooks (previously computed but never
  # passed to `train`), plus a ProfilerHook that writes timeline JSON files
  # into the current directory every 600 seconds.
  hooks = list(train_hooks) + [
      tf.train.ProfilerHook(output_dir='.', save_secs=600, show_memory=False)
  ]

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      # The lambda is invoked synchronously by `train`, so the loop variable
      # is read within this iteration.
      classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                       hooks=hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # Local: only used to scope the optional GPU placement.

  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  # `get_tf_dtype` returns a `tf.DType`; comparing it with the string 'fp16'
  # is never true, which made the documented ValueError unreachable.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  run_eagerly=flags_obj.run_eagerly,
                  metrics=['categorical_accuracy'])

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # --train_steps caps the run at a single (possibly shortened) epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # ExitStack exits the optional device scope with proper arguments, also
  # when `fit`/`evaluate` raises (a bare `__exit__()` call does neither).
  with contextlib.ExitStack() as device_stack:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training loop
      # when not using distribition strategy.
      device_stack.enter_context(tf.device('/device:GPU:0'))

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)

  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
def get_input_dataset(flags_obj, strategy):
  """Returns the test and train input datasets.

  Args:
    flags_obj: An object containing parsed flag values.
    strategy: A distribution strategy, or a falsy value for none.

  Returns:
    A `(train_ds, test_ds)` tuple. `test_ds` is None when
    `flags_obj.skip_eval` is set.

  Raises:
    ValueError: If a TPUStrategy is used and the batch size is not divisible
      by the number of replicas in sync.
  """
  dtype = flags_core.get_tf_dtype(flags_obj)
  # With a TPUStrategy, each replica builds its own dataset via
  # `experimental_distribute_datasets_from_function`.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
  batch_size = flags_obj.batch_size
  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          'Batch size must be divisible by number of replicas : {}'.format(
              strategy.num_replicas_in_sync))

    # As auto rebatching is not supported in
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning dataset to multiple workers in eager mode,
    # we use per-replica batch size. Divisibility was checked above, so
    # integer division is exact.
    batch_size //= strategy.num_replicas_in_sync

  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  def _distribute(dataset_fn):
    """Builds the dataset from `dataset_fn`, distributed per `strategy`."""
    if not strategy:
      return dataset_fn()
    if use_dataset_fn:
      return strategy.experimental_distribute_datasets_from_function(
          dataset_fn)
    return strategy.experimental_distribute_dataset(dataset_fn())

  def _train_dataset_fn(ctx=None):
    return input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=batch_size,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        input_context=ctx,
        drop_remainder=True)

  train_ds = _distribute(_train_dataset_fn)

  test_ds = None
  if not flags_obj.skip_eval:
    def _test_data_fn(ctx=None):
      return input_fn(
          is_training=False,
          data_dir=flags_obj.data_dir,
          batch_size=batch_size,
          parse_record_fn=imagenet_preprocessing.parse_record,
          dtype=dtype,
          input_context=ctx)

    test_ds = _distribute(_test_data_fn)

  return train_ds, test_ds
def resnet_main(flags_obj, model_function, input_function, shape=None):
  """Shared main loop for ResNet Models.

  Training runs in cycles of `epochs_between_evals` epochs, each followed by
  an evaluation. The final cycle trains only the remaining epochs, so exactly
  `train_epochs` epochs are trained in total.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags.export_dir is passed.
  """

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  if flags_obj.multi_gpu:
    validate_batch_size_for_multi_gpu(flags_obj.batch_size)

    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
    # and (2) wrap the optimizer. The first happens here, and (2) happens
    # in the model_fn itself when the optimizer is defined.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_function,
        loss_reduction=tf.losses.Reduction.MEAN)

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  # Set up a RunConfig to save checkpoint and set session config.
  run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9,
                                                session_config=session_config)
  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'multi_gpu': flags_obj.multi_gpu,
          'version': int(flags_obj.version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
  benchmark_logger.log_run_info('resnet')

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      batch_size=flags_obj.batch_size,
      benchmark_log_dir=flags_obj.benchmark_log_dir)

  def input_fn_train(num_epochs):
    return input_function(True, flags_obj.data_dir, flags_obj.batch_size,
                          num_epochs, flags_obj.num_parallel_calls,
                          flags_obj.multi_gpu)

  def input_fn_eval():
    return input_function(False, flags_obj.data_dir, flags_obj.batch_size,
                          1, flags_obj.num_parallel_calls, flags_obj.multi_gpu)

  # Build the per-cycle epoch schedule. The previous floor division
  # (`train_epochs // epochs_between_evals`) silently dropped any remainder
  # epochs. It also skipped training entirely when epochs_between_evals
  # exceeded train_epochs.
  # Example: train_epochs=25, epochs_between_evals=10 -> [10, 10, 5].
  # Integer ceiling division keeps this exact for ints. A zero
  # epochs_between_evals still raises ZeroDivisionError; a non-positive
  # train_epochs yields no cycles.
  train_epochs = flags_obj.train_epochs
  epochs_between_evals = flags_obj.epochs_between_evals
  n_loops = max(0, -(-train_epochs // epochs_between_evals))
  schedule = [epochs_between_evals] * n_loops
  if schedule:
    schedule[-1] = train_epochs - sum(schedule[:-1])
  total_training_cycle = len(schedule)

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting a training cycle: %d/%d',
                    cycle_index, total_training_cycle)

    # The lambda is invoked synchronously by `train`, so the loop variable is
    # read within this iteration.
    classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                     hooks=train_hooks,
                     max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags.max_train_steps is generally associated with testing and profiling.
    # As a result it is frequently called with synthetic data, which will
    # iterate forever. Passing steps=flags.max_train_steps allows the eval
    # (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    warn_on_multi_gpu_export(flags_obj.multi_gpu)

    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If synthetic data is requested for a dataset other than
      ImageNet, or if `flags_obj.optimizer` / `flags_obj.model` name an
      unsupported choice.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # pylint: disable=g-import-not-at-top

  # The CUDA runtime is optional. It is only used to bracket model.fit with
  # cudaProfilerStart/Stop so an external profiler captures just training.
  try:
    cudart = ctypes.CDLL('libcudart.so')
  except OSError:
    cudart = None

  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      dtype, flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether
    # get_next_as_optional behavior is enabled in DistributedIterator. If
    # true, the last partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # resnet_model.resnet50 always takes channels-last input. The Keras
  # application models (e.g. MobileNet) follow the Keras backend image data
  # format instead. use_keras_image_data_format selects whether the image
  # preprocessor output follows the backend format or is always channels-last.
  use_keras_image_data_format = flags_obj.keras_application_models
  preprocessing = (imagenet_preprocessing if flags_obj.dataset == "imagenet"
                   else cifar_preprocessing)
  input_shape = (preprocessing.HEIGHT, preprocessing.WIDTH,
                 preprocessing.NUM_CHANNELS)
  if (use_keras_image_data_format and
      tf.keras.backend.image_data_format() == 'channels_first'):
    input_shape = (preprocessing.NUM_CHANNELS, preprocessing.HEIGHT,
                   preprocessing.WIDTH)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    # Explicit check rather than `assert`, which is stripped under -O.
    if flags_obj.dataset != "imagenet":
      raise ValueError(
          f"Expect to only work with ImageNet, but have {flags_obj.dataset}")
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=preprocessing.DEFAULT_IMAGE_SIZE,
        width=preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=preprocessing.NUM_CHANNELS,
        num_classes=preprocessing.NUM_CLASSES,
        use_keras_image_data_format=use_keras_image_data_format,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    else:
      # Previously this fell through and failed later with UnboundLocalError.
      raise ValueError('Unsupported optimizer: %s' % flags_obj.optimizer)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
      model = resnet_model.resnet50(
          num_classes=preprocessing.NUM_CLASSES, input_shape=input_shape)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=preprocessing.NUM_CLASSES,
          input_shape=input_shape,
          layers=tf.keras.layers)
    elif flags_obj.keras_application_models:
      model_kfn = keras_app_models.get(flags_obj.model, None)
      if model_kfn is None:
        raise ValueError("No keras application model name %s" %
                         flags_obj.model)
      model = model_kfn(weights=None,
                        input_shape=input_shape,
                        classes=preprocessing.NUM_CLASSES)
    else:
      # Previously this fell through and failed later with UnboundLocalError.
      raise ValueError('Unsupported model: %s' % flags_obj.model)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # ExitStack guarantees the optional device scope is exited exactly once,
  # with proper exception info, even if training/eval/export raises.
  with contextlib.ExitStack() as stack:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training
      # loop when not using distribution strategy.
      stack.enter_context(tf.device('/device:GPU:0'))

    cuda_status = cudart.cudaProfilerStart() if cudart is not None else None
    try:
      history = model.fit(train_input_dataset,
                          epochs=train_epochs,
                          steps_per_epoch=steps_per_epoch,
                          callbacks=callbacks,
                          validation_steps=num_eval_steps,
                          validation_data=validation_data,
                          validation_freq=flags_obj.epochs_between_evals,
                          verbose=2)
    finally:
      # Only stop the profiler if starting it succeeded (cudaSuccess == 0).
      if cuda_status == 0:
        cudart.cudaProfilerStop()

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)

    if flags_obj.pruning_method:
      model = tfmot.sparsity.keras.strip_pruning(model)
    if flags_obj.enable_checkpoint_and_export:
      if dtype == tf.bfloat16:
        logging.warning('Keras model.save does not support bfloat16 dtype.')
      else:
        # Keras model.save assumes a float32 input designature.
        export_path = os.path.join(flags_obj.model_dir, 'saved_model')
        model.save(export_path, include_optimizer=False)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. It is called with keyword
      arguments `mode` ("train", "validate" or "predict"), `data_dir`,
      `batch_size` and `num_epochs` (training additionally passes `num_gpus`
      and `dtype`).
    dataset_name: the name of the dataset for training and evaluation. This
      is used for logging purpose.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.

  NOTE(review): the `if flags_obj.predict_only:` guard around the prediction
  block is commented out. As a result, this function always runs prediction,
  writes `validate_result.txt` to the current working directory and returns
  early. The train / evaluate / export code after that return is currently
  unreachable.
  """
  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  num_gpus = flags_core.get_num_gpus(flags_obj)
  if num_gpus == 0:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
  elif num_gpus == 1:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
  else:
    distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)

  run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                      session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  benchmark_logger = logger.config_benchmark_logger(flags_obj)
  benchmark_logger.log_run_info('resnet', dataset_name, run_params)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, batch_size=flags_obj.batch_size)

  # All input functions use the same qualified helper to split the global
  # batch across devices.
  def input_fn_train(num_epochs):
    return input_function(
        mode="train",
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus),
        num_epochs=num_epochs,
        num_gpus=num_gpus,
        dtype=flags_core.get_tf_dtype(flags_obj))

  def input_fn_eval():
    return input_function(
        mode="validate",
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus),
        num_epochs=1)

  def input_fn_pred():
    return input_function(
        mode="predict",
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus),
        num_epochs=1)

  # Prediction. The `if flags_obj.predict_only:` gate is intentionally left
  # disabled (see docstring), so this always runs and returns.
  result = classifier.predict(input_fn=input_fn_pred)
  predicted_values = np.stack([r["predictions"] for r in result], axis=0)
  df = pd.DataFrame(predicted_values)
  df.to_csv("validate_result.txt")
  return

  # train  (unreachable while the predict_only gate above is disabled)
  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
    schedule, n_loops = [0], 1
  else:
    # Compute the number of times to loop while training. All but the last
    # pass will train for `epochs_between_evals` epochs, while the last will
    # train for the number needed to reach `training_epochs`. For instance if
    #   train_epochs = 25 and epochs_between_evals = 10
    # schedule will be set to [10, 10, 5]. That is to say, the loop will:
    #   Train for 10 epochs and then evaluate.
    #   Train for another 10 epochs and then evaluate.
    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
    n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # over counting.

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      # train() consumes input_fn synchronously, so the lambda sees the
      # current num_train_epochs.
      classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                       hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # Evaluation is capped at a hard-coded 100 steps per cycle so that it
    # terminates even when the input pipeline iterates forever (e.g.
    # synthetic data used for testing and profiling).
    eval_results = classifier.evaluate(input_fn=input_fn_eval, steps=100)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['mse']):
      break

  # save model for serving
  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def net_main(flags_obj, model_function, input_function, net_data_configs,
             shape=None):
  """Shared Estimator train/eval main loop (logged as 'meshnet').

  Runs a short warm-up `train` call (max_steps=10) and logs the peak memory
  reported by `max_bytes_in_use`. It then loops over the training schedule,
  evaluating after each cycle. Per-cycle timings, accuracy, mean IoU and
  per-class IoU are appended to `<model_dir>/log_metric.txt`.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with all
      the relevant flags for running and passed to estimator.
    net_data_configs: dataset/network configuration. Its 'dataset_name',
      'sg_settings' and 'dset_metas' entries are read here, and the whole
      object is passed to the model as params['net_data_configs'].
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # apply_clean() may delete model_dir (presumably when --clean is set;
  # confirm in model_helpers). Ensure the directory exists before opening the
  # metric log inside it.
  tf.gfile.MakeDirs(flags_obj.model_dir)
  metric_logfn = os.path.join(flags_obj.model_dir, 'log_metric.txt')
  # The context manager closes the metric log, which previously leaked.
  with open(metric_logfn, 'a') as metric_logf:
    from tensorflow.contrib.memory_stats.ops import gen_memory_stats_ops
    max_memory_usage = gen_memory_stats_ops.max_bytes_in_use()

    # Using the Winograd non-fused algorithms provides a small performance
    # boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads
    # and intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(
        train_distribute=distribution_strategy, session_config=session_config)

    # Initialize our model with all but the dense layer from pretrained resnet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
      warm_start_settings = tf.estimator.WarmStartSettings(
          flags_obj.pretrained_model_checkpoint_path,
          vars_to_warm_start='^(?!.*dense)')
    else:
      warm_start_settings = None

    classifier = tf.estimator.Estimator(
        model_fn=model_function, model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings, params={
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'weight_decay': flags_obj.weight_decay,
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune,
            'examples_per_epoch': flags_obj.examples_per_epoch,
            'net_data_configs': net_data_configs
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    dataset_name = net_data_configs['dataset_name']
    if flags_obj.use_synthetic_data:
      dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('meshnet', dataset_name, run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
      return input_function(
          is_training=True, data_dir=flags_obj.data_dir,
          batch_size=distribution_utils.per_device_batch_size(
              flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
          num_epochs=num_epochs,
          num_gpus=flags_core.get_num_gpus(flags_obj),
          examples_per_epoch=flags_obj.examples_per_epoch,
          sg_settings=net_data_configs['sg_settings'])

    def input_fn_eval():
      return input_function(
          is_training=False, data_dir=flags_obj.data_dir,
          batch_size=distribution_utils.per_device_batch_size(
              flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
          num_epochs=1,
          sg_settings=net_data_configs['sg_settings'])

    if flags_obj.eval_only or flags_obj.pred_ply or not flags_obj.train_epochs:
      # If --eval_only is set, perform a single loop with zero train epochs.
      schedule, n_loops = [0], 1
    else:
      # Compute the number of times to loop while training. All but the last
      # pass will train for `epochs_between_evals` epochs, while the last
      # will train for the number needed to reach `training_epochs`. For
      # instance if
      #   train_epochs = 25 and epochs_between_evals = 10
      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
      #   Train for 10 epochs and then evaluate.
      #   Train for another 10 epochs and then evaluate.
      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
      n_loops = math.ceil(
          flags_obj.train_epochs / flags_obj.epochs_between_evals)
      schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
      schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # over counting.

    # Short warm-up run (up to global step 10) so the memory-stats op below
    # has something to report. This runs even in eval-only mode.
    classifier.train(input_fn=lambda: input_fn_train(1), hooks=train_hooks,
                     max_steps=10)
    with tf.Session() as sess:
      max_memory_usage_v = sess.run(max_memory_usage)
      tf.logging.info('\n\nmemory usage: %0.3f G\n\n' %
                      (max_memory_usage_v * 1.0 / 1e9))

    best_acc, _ = load_saved_best(flags_obj.model_dir)
    for cycle_index, num_train_epochs in enumerate(schedule):
      tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))
      t0 = time.time()
      train_t = 0
      if num_train_epochs:
        classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                         hooks=train_hooks,
                         max_steps=flags_obj.max_train_steps)
        # Average wall-clock seconds per training epoch in this cycle.
        train_t = (time.time() - t0) / num_train_epochs

      tf.logging.info('Starting to evaluate.')

      # flags_obj.max_train_steps is generally associated with testing and
      # profiling. As a result it is frequently called with synthetic data,
      # which will iterate forever. Passing steps=flags_obj.max_train_steps
      # allows the eval (which is generally unimportant in those
      # circumstances) to terminate. Note that eval will run for
      # max_train_steps each loop, regardless of the global_step count.
      t0 = time.time()
      eval_results = classifier.evaluate(
          input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
      eval_t = time.time() - t0

      if flags_obj.pred_ply:
        pred_generator = classifier.predict(input_fn=input_fn_eval)
        num_classes = net_data_configs['dset_metas'].num_classes
        gen_pred_ply(eval_results, pred_generator, flags_obj.model_dir,
                     num_classes)

      benchmark_logger.log_evaluation_result(eval_results)

      if model_helpers.past_stop_threshold(
          flags_obj.stop_threshold, eval_results['accuracy']):
        break

      # Only cycles that actually trained can produce a new "best" model.
      cur_is_best = ''
      if num_train_epochs and eval_results['accuracy'] > best_acc:
        best_acc = eval_results['accuracy']
        save_cur_model_as_best_acc(flags_obj.model_dir, best_acc)
        cur_is_best = 'best'

      global_step = cur_global_step(flags_obj.model_dir)
      # NOTE(review): this computes global_step * num_gpus / examples_per_epoch
      # without involving batch size; confirm the intended epoch scaling.
      epoch = int(global_step / flags_obj.examples_per_epoch *
                  flags_obj.num_gpus)
      ious_str = get_ious_str(eval_results['cm'],
                              net_data_configs['dset_metas'],
                              eval_results['mean_iou'])
      metric_logf.write(
          '\n{} train t:{:.1f}  eval t:{:.1f} \teval acc:{:.3f} \tmean_iou:{:.3f} {} {}\n'
          .format(epoch, train_t, eval_t, eval_results['accuracy'],
                  eval_results['mean_iou'], cur_is_best, ious_str))
      metric_logf.flush()

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  dtype = flags_core.get_tf_dtype(flags_obj)

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  train_ds, test_ds = get_input_dataset(flags_obj, strategy)
  train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)

  time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                          flags_obj.log_steps)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)
  with strategy_scope:
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                  dtype=dtype,
                                  batch_size=flags_obj.batch_size)

    optimizer = tf.keras.optimizers.SGD(
        learning_rate=keras_common.BASE_LEARNING_RATE,
        momentum=0.9,
        nesterov=True)

    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    def train_step(train_ds_inputs):
      """Training StepFn."""

      def step_fn(inputs):
        """Per-Replica StepFn."""
        images, labels = inputs
        with tf.GradientTape() as tape:
          logits = model(images, training=True)

          prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
              labels, logits)
          # Scale by 1/batch_size (presumably the global batch size) so the
          # SUM-reduction across replicas below yields a mean loss.
          loss1 = tf.reduce_sum(prediction_loss) * (1.0 / flags_obj.batch_size)
          # Regularization losses are divided by the replica count so that
          # summing across replicas counts them once.
          loss2 = (tf.reduce_sum(model.losses) /
                   tf.distribute.get_strategy().num_replicas_in_sync)
          loss = loss1 + loss2

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        training_accuracy.update_state(labels, logits)
        return loss

      if strategy:
        per_replica_losses = strategy.experimental_run_v2(
            step_fn, args=(train_ds_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_losses, axis=None)
      else:
        return step_fn(train_ds_inputs)

    def test_step(test_ds_inputs):
      """Evaluation StepFn."""

      def step_fn(inputs):
        images, labels = inputs
        logits = model(images, training=False)
        loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits)
        loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
        test_loss.update_state(loss)
        test_accuracy.update_state(labels, logits)

      if strategy:
        strategy.experimental_run_v2(step_fn, args=(test_ds_inputs,))
      else:
        step_fn(test_ds_inputs)

    if flags_obj.use_tf_function:
      train_step = tf.function(train_step)
      test_step = tf.function(test_step)

    # Stays None if no epoch runs (train_epochs == 0). Previously this case
    # raised UnboundLocalError when building train_result below.
    train_loss = None

    time_callback.on_train_begin()
    for epoch in range(train_epochs):
      train_iter = iter(train_ds)
      total_loss = 0.0
      training_accuracy.reset_states()

      for step in range(train_steps):
        optimizer.lr = keras_imagenet_main.learning_rate_schedule(
            epoch, step, train_steps, flags_obj.batch_size)

        time_callback.on_batch_begin(step + epoch * train_steps)
        total_loss += train_step(next(train_iter))
        time_callback.on_batch_end(step + epoch * train_steps)

      train_loss = total_loss / train_steps
      # Metric results are fractions; scale to percent to match the "%%"
      # suffix in the log message.
      logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
                   train_loss.numpy(),
                   training_accuracy.result().numpy() * 100,
                   epoch)

      if (not flags_obj.skip_eval and
          (epoch + 1) % flags_obj.epochs_between_evals == 0):
        test_loss.reset_states()
        test_accuracy.reset_states()

        test_iter = iter(test_ds)
        for _ in range(eval_steps):
          test_step(next(test_iter))

        logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                     test_loss.result().numpy(),
                     test_accuracy.result().numpy() * 100,
                     epoch)

    time_callback.on_train_end()

    eval_result = None
    train_result = None
    if not flags_obj.skip_eval:
      eval_result = [test_loss.result().numpy(),
                     test_accuracy.result().numpy()]
      if train_loss is not None:
        train_result = [train_loss.numpy(),
                        training_accuracy.result().numpy()]

    stats = build_stats(train_result, eval_result, time_callback)
    return stats
def gen_estimator(period=None):
  """Create the ResNet `tf.estimator.Estimator` from the parsed flags.

  The function reads `flags_obj`, `flags_core`, `model_function`,
  `run_config`, `total_train_steps`, `install_hyperdash` and `Experiment`
  from the surrounding scope; it takes none of them as arguments.

  When hyperdash is installed and `flags_obj.use_hyperdash` is set, the
  function does three things:
    - creates an `Experiment` named after the last path component of
      `flags_obj.model_dir`;
    - registers a fixed set of hyperparameters with `exp.param`;
    - passes the values that `exp.param` returns on to the estimator.
  `dtype` and `autoaugment_type` are registered for tracking only.

  Args:
    period: Unused.

  Returns:
    Tuple `(classifier, exp)`. `exp` is `None` unless hyperdash tracking is
    enabled.
  """
  hparams = {
      'resnet_size': int(flags_obj.resnet_size),
      'data_format': flags_obj.data_format,
      'batch_size': flags_obj.batch_size,
      'resnet_version': int(flags_obj.resnet_version),
      'loss_scale': flags_core.get_loss_scale(flags_obj),
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'num_epochs_per_decay': flags_obj.num_epochs_per_decay,
      'learning_rate_decay_factor': flags_obj.learning_rate_decay_factor,
      'end_learning_rate': flags_obj.end_learning_rate,
      'learning_rate_decay_type': flags_obj.learning_rate_decay_type,
      'weight_decay': flags_obj.weight_decay,
      'zero_gamma': flags_obj.zero_gamma,
      'lr_warmup_epochs': flags_obj.lr_warmup_epochs,
      'base_learning_rate': flags_obj.base_learning_rate,
      'use_resnet_d': flags_obj.use_resnet_d,
      'use_dropblock': flags_obj.use_dropblock,
      'dropblock_kp': [float(kp) for kp in flags_obj.dropblock_kp],
      'label_smoothing': flags_obj.label_smoothing,
      'momentum': flags_obj.momentum,
      'bn_momentum': flags_obj.bn_momentum,
      'embedding_size': flags_obj.embedding_size,
      'train_epochs': flags_obj.train_epochs,
      'piecewise_lr_boundary_epochs': [
          int(epoch) for epoch in flags_obj.piecewise_lr_boundary_epochs
      ],
      'piecewise_lr_decay_rates': [
          float(rate) for rate in flags_obj.piecewise_lr_decay_rates
      ],
      'use_ranking_loss': flags_obj.use_ranking_loss,
      'use_se_block': flags_obj.use_se_block,
      'use_sk_block': flags_obj.use_sk_block,
      'mixup_type': flags_obj.mixup_type,
      'kd_temp': flags_obj.kd_temp,
      'no_downsample': flags_obj.no_downsample,
      'dataset_name': flags_obj.dataset_name,
      'anti_alias_filter_size': flags_obj.anti_alias_filter_size,
      'anti_alias_type': flags_obj.anti_alias_type,
      'cls_loss_type': flags_obj.cls_loss_type,
      'logit_type': flags_obj.logit_type,
      'arc_s': flags_obj.arc_s,
      'arc_m': flags_obj.arc_m,
      'pool_type': flags_obj.pool_type,
      'bl_alpha': flags_obj.bl_alpha,
      'bl_beta': flags_obj.bl_beta,
  }

  exp = None
  if install_hyperdash and flags_obj.use_hyperdash:
    exp = Experiment(flags_obj.model_dir.split("/")[-1])
    # (name, value, feeds_estimator), listed in registration order.
    tracked = (
        ("resnet_size", int(flags_obj.resnet_size), True),
        ("batch_size", flags_obj.batch_size, True),
        ("dtype", flags_obj.dtype, False),
        ("learning_rate_decay_type", flags_obj.learning_rate_decay_type, True),
        ("weight_decay", flags_obj.weight_decay, True),
        ("zero_gamma", flags_obj.zero_gamma, True),
        ("lr_warmup_epochs", flags_obj.lr_warmup_epochs, True),
        ("base_learning_rate", flags_obj.base_learning_rate, True),
        ("use_dropblock", flags_obj.use_dropblock, True),
        ("dropblock_kp", [float(kp) for kp in flags_obj.dropblock_kp], True),
        ("piecewise_lr_boundary_epochs",
         [int(epoch) for epoch in flags_obj.piecewise_lr_boundary_epochs],
         True),
        ("piecewise_lr_decay_rates",
         [float(rate) for rate in flags_obj.piecewise_lr_decay_rates], True),
        ("mixup_type", flags_obj.mixup_type, True),
        ("dataset_name", flags_obj.dataset_name, True),
        ("autoaugment_type", flags_obj.autoaugment_type, False),
    )
    for name, value, feeds_estimator in tracked:
      registered = exp.param(name, value)
      if feeds_estimator:
        hparams[name] = registered

  hparams['with_drawing_bbox'] = flags_obj.with_drawing_bbox
  hparams['train_steps'] = total_train_steps

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params=hparams)
  return classifier, exp
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  This variant loads CIFAR-10 as in-memory NumPy arrays through
  `tf.keras.datasets.cifar10`, builds a `Conv4_model`, and trains it with a
  `SynchronousSGDOptimizer`. The per-epoch step count is split across
  `current_cluster_size()` workers. Initial variables are broadcast with
  `BroadcastGlobalVariablesCallback`.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType, so compare against tf.float16. Comparing
  # against the flag string 'fp16' is always False, which would make the
  # documented ValueError unreachable.
  if dtype == tf.float16:
    raise ValueError(
        'dtype fp16 is not supported in Keras. Use the default '
        'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  # NOTE(review): strategy_scope is never entered below; the model is built
  # and compiled outside any strategy scope.
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn
  # NOTE(review): input_fn is not consumed below. Training data comes from
  # tf.keras.datasets.cifar10 instead.

  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  # Divide by 255 so pixel values fall in [0, 1].
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255
  # NOTE(review): num_classes is not defined in this function; presumably a
  # module-level constant. Confirm.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes)

  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  logging.info(opt.__dict__)
  optimizer = SynchronousSGDOptimizer(opt, use_locking=True)
  # Share the wrapped optimizer's hyperparameter dict with the wrapper,
  # presumably so Keras code that reads/writes hyperparameters (e.g. the
  # learning rate) can reach them.
  optimizer._hyper = opt._hyper  # pylint: disable=protected-access
  logging.info(optimizer.__dict__)
  model = Conv4_model(x_train, num_classes)

  # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
  # a valid arg for this model. Also remove as a valid flag.
  if flags_obj.force_v2_in_keras_compile is not None:
    model.compile(
        loss='categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['accuracy']),
        run_eagerly=flags_obj.run_eagerly,
        experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
  else:
    model.compile(
        loss='categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['accuracy']),
        run_eagerly=flags_obj.run_eagerly)

  cluster_size = current_cluster_size()
  # Split the per-epoch step budget across the workers in the cluster.
  steps_per_epoch = (
      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  steps_per_epoch = steps_per_epoch // cluster_size
  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                   cluster_size, learning_rate_schedule)
  callbacks.append(BroadcastGlobalVariablesCallback())

  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

  num_eval_steps = (
      cifar_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = (x_test, y_test)
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  tf.compat.v1.logging.info(x_train.shape)
  history = model.fit(
      x_train,
      y_train,
      batch_size=flags_obj.batch_size,
      epochs=train_epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=callbacks,
      validation_steps=num_eval_steps,
      # None when --skip_eval is set, so fit() runs no validation.
      validation_data=validation_data,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    # Features and labels must be separate arguments. Passing the tuple
    # (x_test, y_test) as `x` would feed both arrays as model inputs and
    # supply no labels.
    eval_output = model.evaluate(x_test,
                                 y_test,
                                 batch_size=flags_obj.batch_size,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def run(flags_obj):
  """Train and evaluate ResNet-50 on ImageNet with the native Keras APIs.

  When the dtype flag resolves to float16, the 'infer_float32_vars'
  mixed-precision policy is installed. The optimizer is also wrapped in a
  `LossScaleOptimizer` using the loss scale from the flags.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as built by
    `keras_common.build_stats`.
  """
  session_config = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=session_config)
    else:
      tf.keras.backend.set_session(tf.Session(config=session_config))
  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

  dtype = flags_core.get_tf_dtype(flags_obj)
  is_float16 = dtype == 'float16'
  if is_float16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  image_format = flags_obj.data_format
  if image_format is None:
    if tf.test.is_built_with_cuda():
      image_format = 'channels_first'
    else:
      image_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_format)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)

  # Keyword arguments shared by the train and eval pipelines.
  shared_input_kwargs = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype)
  train_input_dataset = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_kwargs)
  eval_input_dataset = input_fn(is_training=False, **shared_input_kwargs)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  with keras_common.get_strategy_scope(strategy):
    optimizer = keras_common.get_optimizer()
    if is_float16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    model = resnet_model.resnet50(
        num_classes=imagenet_main.NUM_CLASSES, dtype=dtype)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  if flags_obj.train_steps:
    # An explicit step cap collapses training into a single epoch.
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1
  else:
    epochs = flags_obj.train_epochs

  full_eval_steps = (
      imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size)
  run_eval = not flags_obj.skip_eval
  if run_eval:
    eval_steps = full_eval_steps
    validation_data = eval_input_dataset
  else:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    eval_steps = None
    validation_data = None

  history = model.fit(
      train_input_dataset,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=[time_callback, lr_callback, tensorboard_callback],
      validation_steps=eval_steps,
      validation_data=validation_data,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)

  eval_output = (
      model.evaluate(eval_input_dataset, steps=eval_steps, verbose=2)
      if run_eval else None)
  return keras_common.build_stats(history, eval_output, time_callback)
def run(flags_obj):
  """Train and evaluate Keras ResNet-50 on ImageNet, optionally multi-worker.

  The distribution strategy receives the worker count returned by
  `distribution_utils.configure_cluster()`. When the dtype flag resolves to
  float16, the 'infer_float32_vars' mixed-precision policy is installed. The
  optimizer is also wrapped in a `LossScaleOptimizer`.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as built by
    `keras_common.build_stats`.
  """
  cfg_proto = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  running_v1 = not keras_common.is_v2_0()
  if running_v1 and flags_obj.enable_eager:
    tf.compat.v1.enable_eager_execution(config=cfg_proto)
  elif running_v1:
    keras_session = tf.Session(config=cfg_proto)
    tf.keras.backend.set_session(keras_session)
  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

  dtype = flags_core.get_tf_dtype(flags_obj)
  mixed_precision = (dtype == 'float16')
  if mixed_precision:
    fp16_policy = tf.keras.mixed_precision.experimental.Policy(
        'infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(fp16_policy)

  chosen_format = flags_obj.data_format
  if chosen_format is None:
    chosen_format = ('channels_first' if tf.test.is_built_with_cuda()
                     else 'channels_last')
  tf.keras.backend.set_image_data_format(chosen_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  def _build_dataset(is_training, **extra_kwargs):
    # Both pipelines share everything except is_training and extra kwargs.
    return input_fn(
        is_training=is_training,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype,
        **extra_kwargs)

  train_input_dataset = _build_dataset(
      True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  eval_input_dataset = _build_dataset(False)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster())

  with keras_common.get_strategy_scope(strategy):
    base_optimizer = keras_common.get_optimizer()
    # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
    # can be enabled with a single line of code.
    optimizer = (
        tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            base_optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
        if mixed_precision else base_optimizer)

    model = resnet_model.resnet50(
        num_classes=imagenet_main.NUM_CLASSES, dtype=dtype)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = (
      keras_common.get_callbacks(
          learning_rate_schedule, imagenet_main.NUM_IMAGES['train']))
  callback_list = [time_callback, lr_callback, tensorboard_callback]

  full_epoch_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_steps = (min(flags_obj.train_steps, full_epoch_steps)
                 if flags_obj.train_steps else full_epoch_steps)
  train_epochs = 1 if flags_obj.train_steps else flags_obj.train_epochs

  full_eval_steps = (
      imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size)
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps, validation_data = None, None
  else:
    num_eval_steps, validation_data = full_eval_steps, eval_input_dataset

  history = model.fit(
      train_input_dataset,
      epochs=train_epochs,
      steps_per_epoch=train_steps,
      callbacks=callback_list,
      validation_steps=num_eval_steps,
      validation_data=validation_data,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(
        eval_input_dataset, steps=num_eval_steps, verbose=2)

  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  import contextlib  # pylint: disable=g-import-not-at-top

  keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType, so compare against tf.float16. The flag
  # string 'fp16' never compares equal, which made this check dead.
  if dtype == tf.float16:
    raise ValueError(
        'dtype fp16 is not supported in Keras. Use the default '
        'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    synthetic_util.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    synthetic_util.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in normalization
      # layer, which triggers tf.where and leads to extra memory copy of input
      # sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record)
    # Use DATA auto-sharding for the eval dataset. This is applied only when
    # the dataset exists: with --skip_eval it is None and has no with_options.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    eval_input_dataset = eval_input_dataset.with_options(options)

  steps_per_epoch = (
      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    initial_learning_rate = (
        common.BASE_LEARNING_RATE * flags_obj.batch_size / 128)
    # LR_SCHEDULE entries are read as (multiplier, epoch) pairs here.
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
        values=[initial_learning_rate] +
        [p[0] * initial_learning_rate for p in LR_SCHEDULE])

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = resnet_cifar_model.resnet56(
        classes=cifar_preprocessing.NUM_CLASSES)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks()

  if not flags_obj.use_tensor_lr:
    lr_callback = LearningRateBatchScheduler(
        schedule=learning_rate_schedule,
        batch_size=flags_obj.batch_size,
        steps_per_epoch=steps_per_epoch)
    callbacks.append(lr_callback)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      cifar_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # The ExitStack is an empty context unless explicit GPU placement is
  # requested. The `with` below exits the device scope even if fit/evaluate
  # raises; calling `__exit__()` without its three arguments raises TypeError.
  device_scope = contextlib.ExitStack()
  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    device_scope.enter_context(tf.device('/device:GPU:0'))

  with device_scope:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType, so compare against tf.float16. The flag
  # string 'fp16' never compares equal, which made this check dead.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    # An explicit step cap collapses training into a single epoch.
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=2)
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Build the ResNet Estimator and optionally export it as a SavedModel.

  This variant only constructs a `tf.estimator.Estimator`, optionally
  warm-started from a pretrained checkpoint. When `flags_obj.export_dir` is
  set, it also exports the Estimator. It performs no training or evaluation
  and returns None.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: accepted for interface compatibility with the training
      variants. Unused here because this function never trains or evaluates.
    dataset_name: the name of the dataset. Unused here and accepted for
      interface compatibility.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.
  """
  print("RESNET MAIN")
  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.ConfigProto(allow_soft_placement=True)

  # Time-based checkpointing every 24 hours.
  run_config = tf.estimator.RunConfig(
      session_config=session_config, save_checkpoints_secs=60 * 60 * 24)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      warm_start_from=warm_start_settings,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)
def run(flags_obj):
  """Train (and optionally evaluate) ResNet on ImageNet with a custom loop.

  Training is driven by `controller.Controller` around a
  `resnet_runnable.ResnetRunnable`. Checkpoints go through a
  `tf.train.CheckpointManager` rooted at `flags_obj.model_dir`.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as produced by `build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager, enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  image_format = flags_obj.data_format
  if image_format is None:
    has_gpu = tf.config.list_physical_devices('GPU')
    image_format = 'channels_first' if has_gpu else 'channels_last'
  tf.keras.backend.set_image_data_format(image_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  total_train_steps = train_epochs * per_epoch_steps
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps',
      train_epochs, per_epoch_steps, total_train_steps, eval_steps)

  tensorboard_enabled = flags_obj.enable_tensorboard
  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if tensorboard_enabled else None)

  # Only the runnable (model/optimizer/variables) is created under the scope.
  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              per_epoch_steps)

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=(per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None))

  ctl = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      train_steps=total_train_steps,
      checkpoint_manager=checkpoint_manager,
      summary_interval=per_epoch_steps if tensorboard_enabled else None,
      eval_steps=eval_steps,
      eval_interval=flags_obj.epochs_between_evals * per_epoch_steps)

  time_callback.on_train_begin()
  ctl.train(evaluate=not flags_obj.skip_eval)
  time_callback.on_train_end()

  return build_stats(runnable, time_callback)
def use_float16():
  """Report whether the parsed `--dtype` flag selects float16.

  Returns:
    The result of comparing `flags_core.get_tf_dtype(flags.FLAGS)` with
    `tf.float16`.
  """
  selected_dtype = flags_core.get_tf_dtype(flags.FLAGS)
  return selected_dtype == tf.float16
def resnet_main(flags_obj, model_function, input_function, dataset_name,
                percent, model_class, shape=None):
  """Shared main loop for ResNet models with clean, adversarial and attack eval.

  Three Estimators share `flags_obj.model_dir` and differ only in the
  `adv_train` / `attack` params passed to `model_function`:

  * `classifier`: both False.
  * `classifier_adv`: `adv_train=True`.
  * `classifier_attack`: `attack=True`.

  Each training cycle trains `classifier_adv` when `flags_obj.adv_train` is
  set, and `classifier` otherwise. It then evaluates all three Estimators.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    percent: forwarded as `percent=` to `input_function` for training input.
      Evaluation uses fixed values: 0 for clean/adv eval and 100 for attack
      eval. Its meaning is defined by `input_function`.
    model_class: currently unused.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run, with these keys:

    * `eval_results`: in the cycle loop, includes an `accuracy` entry that is
      used for the stop threshold.
    * `eval_adv_results`.
    * `eval_atttack_results`: the misspelled key is kept for existing callers.
    * `train_hooks`: the hooks used during training.

    All eval entries are `{}` when `tf.estimator.train_and_evaluate` is used.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Creates a `RunConfig` that checkpoints every 24 hours which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                      session_config=session_config,
                                      save_checkpoints_secs=60 * 60 * 24,
                                      save_checkpoints_steps=None)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  params = {
      'resnet_size': int(flags_obj.resnet_size),
      'data_format': 'channels_last',
      'batch_size': flags_obj.batch_size,
      'resnet_version': int(flags_obj.resnet_version),
      'loss_scale': flags_core.get_loss_scale(flags_obj,
                                              default_for_fp16=128),
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'fine_tune': flags_obj.fine_tune,
      'num_workers': num_workers,
      'adv_train': False,
      'attack': False,
  }

  def _make_estimator(estimator_params):
    # Builds an Estimator that differs from its siblings only in params.
    return tf.compat.v1.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params=estimator_params)

  # Give each Estimator its own params dict instead of mutating one shared
  # dict between constructor calls, so each configuration is explicit.
  classifier = _make_estimator(params)
  classifier_adv = _make_estimator(dict(params, adv_train=True))
  classifier_attack = _make_estimator(dict(params, attack=True))

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                             model_dir=flags_obj.model_dir,
                                             batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs, input_context=None):
    return input_function(
        is_training=True,
        percent=percent,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        input_context=input_context)

  def input_fn_eval():
    return input_function(
        is_training=False,
        percent=0,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  def input_fn_eval_attack():
    return input_function(
        is_training=False,
        percent=100,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                  flags_obj.train_epochs)

  # tf.compat.v1 spelling works in both TF1 and TF2; the top-level
  # tf.global_variables alias does not exist in TF2.
  tf.compat.v1.logging.info(tf.compat.v1.global_variables())
  use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
  if use_train_and_evaluate:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=flags_obj.max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    tf.compat.v1.logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evaluate doesn't return anything in multi-worker
    # case. The adv/attack results must be bound as well, otherwise building
    # `stats` below raises NameError on this path.
    eval_results = {}
    eval_results_adv = {}
    eval_results_attack = {}
  else:
    if train_epochs == 0:
      # If --eval_only is set, perform a single loop with zero train epochs.
      schedule, n_loops = [0], 1
    else:
      # Compute the number of times to loop while training. All but the last
      # pass will train for `epochs_between_evals` epochs, while the last will
      # train for the number needed to reach `training_epochs`. For instance if
      #   train_epochs = 25 and epochs_between_evals = 10
      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
      #   Train for 10 epochs and then evaluate.
      #   Train for another 10 epochs and then evaluate.
      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
      n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
      schedule = [
          flags_obj.epochs_between_evals for _ in range(int(n_loops))
      ]
      schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.

    # Adversarial training updates the adv Estimator; otherwise the clean one.
    trainer = classifier_adv if flags_obj.adv_train else classifier

    for cycle_index, num_train_epochs in enumerate(schedule):
      tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                                int(n_loops))

      if num_train_epochs:
        # Since we are calling trainer.train immediately in each loop, the
        # value of num_train_epochs in the lambda function will not be changed
        # before it is used. So it is safe to ignore the pylint error here
        # pylint: disable=cell-var-from-loop
        trainer.train(
            input_fn=lambda input_context=None: input_fn_train(
                num_train_epochs, input_context=input_context),
            hooks=train_hooks,
            max_steps=flags_obj.max_train_steps)

      # flags_obj.max_train_steps is generally associated with testing and
      # profiling. As a result it is frequently called with synthetic data,
      # which will iterate forever. Passing steps=flags_obj.max_train_steps
      # allows the eval (which is generally unimportant in those circumstances)
      # to terminate.  Note that eval will run for max_train_steps each loop,
      # regardless of the global_step count.
      tf.compat.v1.logging.info('Starting to evaluate clean.')
      eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                         steps=flags_obj.max_train_steps)
      tf.compat.v1.logging.info('Starting to evaluate adv.')
      eval_results_adv = classifier_adv.evaluate(
          input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
      tf.compat.v1.logging.info('Starting to evaluate attack.')
      eval_results_attack = classifier_attack.evaluate(
          input_fn=input_fn_eval_attack, steps=flags_obj.max_train_steps)

      print(
          '########################## clean #############################'
      )
      benchmark_logger.log_evaluation_result(eval_results)
      print(
          '########################## adv #############################')
      benchmark_logger.log_evaluation_result(eval_results_adv)
      print(
          '########################## attack #############################'
      )
      benchmark_logger.log_evaluation_result(eval_results_attack)

      if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                           eval_results['accuracy']):
        break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                            shape,
                                            dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir,
                                 input_receiver_fn,
                                 strip_default_attrs=True)

  stats = {
      'eval_results': eval_results,
      # The misspelled key is kept for callers that already read it.
      'eval_atttack_results': eval_results_attack,
      'eval_adv_results': eval_results_adv,
      'train_hooks': train_hooks,
  }
  return stats
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Float16 is handled by installing a Keras mixed-precision policy and wrapping
  the optimizer in a LossScaleOptimizer.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats built by `common.build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_delay_prefetch:
    common.data_delay_prefetch()
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=imagenet_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in LR_SCHEDULE),
        compute_lr_on_cpu=True)

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
                                                          default_for_fp16=128))

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(
          imagenet_preprocessing.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES,
          dtype=dtype)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  callbacks = common.get_callbacks(
      learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

  train_steps = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  # NOTE(review): each epoch only runs 1/15 of `train_steps`. The divisor is
  # not derived from any flag and looks like an experiment-specific shortcut;
  # confirm it is intended. The clamp keeps a small --train_steps value from
  # producing steps_per_epoch == 0.
  steps_per_epoch = max(1, train_steps // 15)

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  no_dist_strat_device = None
  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribition strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  try:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=1)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=1)
  finally:
    # Always leave the manually-entered device scope, even if fit/evaluate
    # raised. Context-manager protocol requires the three exc_info arguments.
    if no_dist_strat_device is not None:
      no_dist_strat_device.__exit__(None, None, None)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def __init__(self, flags_obj):
  """Init function of TransformerMain.

  Copies flag values into `self.params`, installs a mixed_float16 Keras policy
  when the dtype is float16, and builds the distribution strategy.

  Args:
    flags_obj: Object containing parsed flag values, i.e., FLAGS.

  Raises:
    ValueError: if not using static batch for input data on TPU.
  """
  self.flags_obj = flags_obj
  self.predict_model = None

  # Add flag-defined parameters to params object
  num_gpus = flags_core.get_num_gpus(flags_obj)
  self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

  params["num_gpus"] = num_gpus
  params["use_ctl"] = flags_obj.use_ctl
  params["data_dir"] = flags_obj.data_dir
  params["model_dir"] = flags_obj.model_dir
  params["static_batch"] = flags_obj.static_batch
  params["max_length"] = flags_obj.max_length
  params["decode_batch_size"] = flags_obj.decode_batch_size
  params["decode_max_length"] = flags_obj.decode_max_length
  params["padded_decode"] = flags_obj.padded_decode
  params["num_parallel_calls"] = (
      flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

  params["use_synthetic_data"] = flags_obj.use_synthetic_data
  params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
  params["repeat_dataset"] = None
  params["dtype"] = flags_core.get_tf_dtype(flags_obj)
  params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training

  if params["dtype"] == tf.float16:
    # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
    # like this. What if multiple instances of TransformerTask are created?
    # We should have a better way in the tf.keras.mixed_precision API of doing
    # this.
    loss_scale = flags_core.get_loss_scale(flags_obj,
                                           default_for_fp16="dynamic")
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        "mixed_float16", loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  self.distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=num_gpus,
      tpu_address=flags_obj.tpu or "")
  # `use_tpu` is defined elsewhere on this class (not visible here).
  if self.use_tpu:
    params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
    if not params["static_batch"]:
      raise ValueError("TPU requires static batch for input data.")
  else:
    # logging uses %-style lazy formatting: the message needs a placeholder
    # for every extra argument, otherwise formatting fails at emit time.
    logging.info("Running transformer with num_gpus = %d", num_gpus)

  if self.distribution_strategy:
    logging.info("For training, using distribution strategy: %s",
                 self.distribution_strategy)
  else:
    logging.info("Not using any distribution strategy.")
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats built by `keras_common.build_stats`.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # `get_tf_dtype` maps the --dtype flag to a tf.DType (fp16 -> tf.float16),
  # so the guard must compare against the DType. Comparing to the flag
  # spelling 'fp16' never matched, letting unsupported fp16 runs through.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  # The eval pipeline is only consumed when evaluation is enabled, so skip
  # constructing it under --skip_eval.
  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(is_training=False,
                                  data_dir=flags_obj.data_dir,
                                  batch_size=flags_obj.batch_size,
                                  num_epochs=flags_obj.train_epochs,
                                  parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)

  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Trains in cycles of `flags_obj.epochs_between_evals` epochs, evaluating after
  each cycle. A final shorter cycle covers any remainder, so exactly
  `flags_obj.train_epochs` epochs are trained.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """
  # Use the flags object the caller handed in rather than the global FLAGS.
  model_helpers.apply_clean(flags_obj)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  num_gpus = flags_core.get_num_gpus(flags_obj)
  distribution_strategy = distribution_utils.get_distribution_strategy(
      num_gpus, flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus),
        num_epochs=num_epochs,
        num_gpus=num_gpus)

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus),
        num_epochs=1)

  # Floor division alone silently dropped the remainder epochs (and trained
  # nothing when train_epochs < epochs_between_evals); append a final short
  # cycle for the remainder instead. E.g. 25 epochs, evals every 10 ->
  # schedule [10, 10, 5].
  full_cycles, remainder = divmod(flags_obj.train_epochs,
                                  flags_obj.epochs_between_evals)
  schedule = [flags_obj.epochs_between_evals] * full_cycles
  if remainder:
    schedule.append(remainder)
  total_training_cycle = len(schedule)

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting a training cycle: %d/%d',
                    cycle_index, total_training_cycle)

    # Bind the loop value via a default argument so the input_fn does not
    # depend on late-binding of the loop variable.
    classifier.train(
        input_fn=lambda epochs=num_train_epochs: input_fn_train(epochs),
        hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def run(flags_obj):
  """Train and evaluate ResNet-50 (or a trivial model) on ImageNet via Keras.

  Configures the TF session/eager mode, optional mixed precision (a Keras
  'infer_float32_vars' policy plus a LossScaleOptimizer when the dtype is
  float16), the image data format and the distribution strategy, then drives
  `model.fit` and, unless evaluation is skipped, `model.evaluate`.

  Args:
    flags_obj: Parsed flag values controlling the run.

  Returns:
    The stats dictionary produced by `keras_common.build_stats`.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is the default in tf 2.0 and should not be toggled there.
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    session_config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=session_config)
    else:
      tf.keras.backend.set_session(tf.Session(config=session_config))

  # Flag-driven overrides aimed at better performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_prefetch_with_slack:
    keras_common.data_prefetch_with_slack()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  is_half_precision = dtype == 'float16'
  if is_half_precision:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  image_format = flags_obj.data_format
  if image_format is None:
    image_format = ('channels_first' if tf.test.is_built_with_cuda()
                    else 'channels_last')
  tf.keras.backend.set_image_data_format(image_format)

  worker_count = distribution_utils.configure_cluster()
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=worker_count,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)

  # XLA-GPU cannot handle dynamic shapes, so partial final batches are dropped
  # whenever `enable_xla` is on.
  shared_input_kwargs = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype,
      drop_remainder=flags_obj.enable_xla)

  train_input_dataset = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_kwargs)

  eval_input_dataset = (
      None if flags_obj.skip_eval
      else input_fn(is_training=False, **shared_input_kwargs))

  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=[entry[1] for entry in LR_SCHEDULE[1:]],
        multipliers=[entry[0] for entry in LR_SCHEDULE],
        compute_lr_on_cpu=True)
  else:
    lr_schedule = 0.1

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if is_half_precision:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    # A static Input batch size is only used with XLA in graph mode on a
    # single replica.
    # TODO(b/129861005): Fix OOM issue in eager mode when setting
    # `batch_size` in keras.Input layer.
    # TODO(b/129791381): Specify `input_layer_batch_size` value in
    # DistributionStrategy multi-replica case.
    static_input_batch = (
        flags_obj.enable_xla and not flags_obj.enable_eager and
        not (strategy and strategy.num_replicas_in_sync > 1))
    input_layer_batch_size = (
        flags_obj.batch_size if static_input_batch else None)

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)

    accuracy_metrics = (['sparse_categorical_accuracy']
                        if flags_obj.report_accuracy_metrics else None)
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=accuracy_metrics,
                  cloning=flags_obj.clone_model_in_keras_dist_strat)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  epochs = flags_obj.train_epochs
  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1

  if flags_obj.skip_eval:
    # Build only the training graph: layers whose train/inference behavior
    # differs (e.g. batch norm) otherwise add control-flow ops and memory.
    tf.keras.backend.set_learning_phase(1)
    validation_steps = None
    validation_data = None
  else:
    validation_steps = (
        imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size)
    validation_data = eval_input_dataset

  history = model.fit(train_input_dataset,
                      epochs=epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=validation_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = (
      None if flags_obj.skip_eval
      else model.evaluate(eval_input_dataset,
                          steps=validation_steps,
                          verbose=2))

  return keras_common.build_stats(history, eval_output, callbacks)
def run(flags_obj):
  """Train and evaluate ResNet-50 (or a trivial model) on ImageNet via Keras.

  Sets up the TF session/eager mode, an optional float16 mixed-precision
  policy with a LossScaleOptimizer, the image data format and the
  distribution strategy, then runs `model.fit` and, unless evaluation is
  skipped, `model.evaluate`.

  Args:
    flags_obj: Parsed flag values controlling the run.

  Returns:
    The stats dictionary produced by `keras_common.build_stats`.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is the default in tf 2.0 and should not be toggled there.
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    session_config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=session_config)
    else:
      tf.keras.backend.set_session(tf.Session(config=session_config))

  # Flag-driven override for better GPU thread performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)

  dtype = flags_core.get_tf_dtype(flags_obj)
  is_half_precision = dtype == 'float16'
  if is_half_precision:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  image_format = flags_obj.data_format
  if image_format is None:
    image_format = ('channels_first' if tf.test.is_built_with_cuda()
                    else 'channels_last')
  tf.keras.backend.set_image_data_format(image_format)

  worker_count = distribution_utils.configure_cluster()
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=worker_count)
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)

  # XLA-GPU cannot handle dynamic shapes, so partial final batches are dropped
  # whenever `enable_xla` is on.
  shared_input_kwargs = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype,
      drop_remainder=flags_obj.enable_xla)

  train_input_dataset = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_kwargs)

  eval_input_dataset = (
      None if flags_obj.skip_eval
      else input_fn(is_training=False, **shared_input_kwargs))

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if is_half_precision:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    # A static Input batch size is only used with XLA in graph mode on a
    # single replica.
    # TODO(b/129861005): Fix OOM issue in eager mode when setting
    # `batch_size` in keras.Input layer.
    # TODO(b/129791381): Specify `input_layer_batch_size` value in
    # DistributionStrategy multi-replica case.
    static_input_batch = (
        flags_obj.enable_xla and not flags_obj.enable_eager and
        not (strategy and strategy.num_replicas_in_sync > 1))
    input_layer_batch_size = (
        flags_obj.batch_size if static_input_batch else None)

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  epochs = flags_obj.train_epochs
  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1

  if flags_obj.skip_eval:
    # Build only the training graph: layers whose train/inference behavior
    # differs (e.g. batch norm) otherwise add control-flow ops and memory.
    tf.keras.backend.set_learning_phase(1)
    validation_steps = None
    validation_data = None
  else:
    validation_steps = (
        imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size)
    validation_data = eval_input_dataset

  history = model.fit(train_input_dataset,
                      epochs=epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=validation_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = (
      None if flags_obj.skip_eval
      else model.evaluate(eval_input_dataset,
                          steps=validation_steps,
                          verbose=2))

  return keras_common.build_stats(history, eval_output, callbacks)
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Trains in cycles of `flags_obj.epochs_between_evals` epochs, evaluating after
  each cycle. A final shorter cycle covers any remainder, so exactly
  `flags_obj.train_epochs` epochs are trained.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  num_gpus = flags_core.get_num_gpus(flags_obj)
  if num_gpus == 0:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
  elif num_gpus == 1:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
  else:
    distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)

  run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                      session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  benchmark_logger = logger.config_benchmark_logger(flags_obj)
  benchmark_logger.log_run_info('resnet', dataset_name, run_params)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=per_device_batch_size(flags_obj.batch_size, num_gpus),
        num_epochs=num_epochs)

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=per_device_batch_size(flags_obj.batch_size, num_gpus),
        num_epochs=1)

  # Floor division alone silently dropped the remainder epochs (and trained
  # nothing when train_epochs < epochs_between_evals); append a final short
  # cycle for the remainder instead. E.g. 25 epochs, evals every 10 ->
  # schedule [10, 10, 5].
  full_cycles, remainder = divmod(flags_obj.train_epochs,
                                  flags_obj.epochs_between_evals)
  schedule = [flags_obj.epochs_between_evals] * full_cycles
  if remainder:
    schedule.append(remainder)
  total_training_cycle = len(schedule)

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting a training cycle: %d/%d',
                    cycle_index, total_training_cycle)

    # Bind the loop value via a default argument so the input_fn does not
    # depend on late-binding of the loop variable.
    classifier.train(
        input_fn=lambda epochs=num_train_epochs: input_fn_train(epochs),
        hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Builds the model, optimizer and metrics inside the distribution strategy
  scope, then trains for `train_epochs` epochs in chunks of at most
  `steps_per_loop` steps. Every `epochs_between_evals` epochs it evaluates on
  `eval_steps` batches, unless `--skip_eval` is set.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `--fp16_implementation=graph_rewrite` is requested without
      `--use_tf_function`.

  Returns:
    Dictionary of training and eval stats, as produced by `build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  train_ds, test_ds = get_input_dataset(flags_obj, strategy)
  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  # Never run more steps per host loop than fit in a single epoch.
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
  logging.info(
      "Training %d epochs, each epoch has %d steps, "
      "total steps: %d; Eval %d steps",
      train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
      eval_steps)

  time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                          flags_obj.log_steps)

  with distribution_utils.get_strategy_scope(strategy):
    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        batch_size=flags_obj.batch_size,
        use_l2_regularizer=not flags_obj.single_l2_loss_op)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    optimizer = common.get_optimizer(lr_schedule)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      if not flags_obj.use_tf_function:
        raise ValueError('--fp16_implementation=graph_rewrite requires '
                         '--use_tf_function to be true')
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer, loss_scale)

    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    trainable_variables = model.trainable_variables

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(images, training=True)

        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        # Divide by the global batch size so per-replica losses sum correctly.
        loss = tf.reduce_sum(prediction_loss) * (1.0 / flags_obj.batch_size)
        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

        if flags_obj.single_l2_loss_op:
          # One fused L2 op over all non-batch-norm weights.
          filtered_variables = [
              tf.reshape(v, (-1,))
              for v in trainable_variables
              if 'bn' not in v.name
          ]
          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
              tf.concat(filtered_variables, axis=0))
          loss += (l2_loss / num_replicas)
        else:
          loss += (tf.reduce_sum(model.losses) / num_replicas)

        # Scale the loss for fp16 to keep small gradients from underflowing.
        # The scaled value is kept separate from `loss` so that the reported
        # training-loss metric stays in natural (unscaled) units.
        if flags_obj.dtype == "fp16":
          scaled_loss = optimizer.get_scaled_loss(loss)
        else:
          scaled_loss = loss

      grads = tape.gradient(scaled_loss, trainable_variables)

      # Unscale the grads
      if flags_obj.dtype == "fp16":
        grads = optimizer.get_unscaled_gradients(grads)

      optimizer.apply_gradients(zip(grads, trainable_variables))

      train_loss.update_state(loss)
      training_accuracy.update_state(labels, logits)

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop."""
      for _ in tf.range(steps):
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))

    def train_single_step(iterator):
      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        return step_fn(next(iterator))

    def test_step(iterator):
      """Evaluation StepFn."""
      def step_fn(inputs):
        images, labels = inputs
        logits = model(images, training=False)
        loss = tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                               logits)
        loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
        test_loss.update_state(loss)
        test_accuracy.update_state(labels, logits)

      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        step_fn(next(iterator))

    if flags_obj.use_tf_function:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    train_iter = iter(train_ds)
    time_callback.on_train_begin()
    for epoch in range(train_epochs):
      train_loss.reset_states()
      training_accuracy.reset_states()

      steps_in_current_epoch = 0
      while steps_in_current_epoch < per_epoch_steps:
        time_callback.on_batch_begin(
            steps_in_current_epoch + epoch * per_epoch_steps)
        steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                              steps_per_loop)
        if steps == 1:
          train_single_step(train_iter)
        else:
          # Converts steps to a Tensor to avoid tf.function retracing.
          train_steps(train_iter,
                      tf.convert_to_tensor(steps, dtype=tf.int32))
        time_callback.on_batch_end(
            steps_in_current_epoch + epoch * per_epoch_steps)
        steps_in_current_epoch += steps

      logging.info('Training loss: %s, accuracy: %s at epoch %d',
                   train_loss.result().numpy(),
                   training_accuracy.result().numpy(),
                   epoch + 1)

      if (not flags_obj.skip_eval and
          (epoch + 1) % flags_obj.epochs_between_evals == 0):
        test_loss.reset_states()
        test_accuracy.reset_states()

        test_iter = iter(test_ds)
        for _ in range(eval_steps):
          test_step(test_iter)

        logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                     test_loss.result().numpy(),
                     test_accuracy.result().numpy(),
                     epoch + 1)

    time_callback.on_train_end()

    eval_result = None
    train_result = None
    if not flags_obj.skip_eval:
      eval_result = [
          test_loss.result().numpy(),
          test_accuracy.result().numpy()
      ]
      train_result = [
          train_loss.result().numpy(),
          training_accuracy.result().numpy()
      ]

    stats = build_stats(train_result, eval_result, time_callback)
    return stats
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None,
                num_images=None,
                zeroshot_eval=False):
  """Builds an Estimator from flags, runs train/eval cycles, optionally exports.

  Args:
    flags_obj: Parsed flag values. Many attributes are read, mostly inside
      `gen_estimator` when building the Estimator params.
    model_function: Passed as `model_fn` to `tf.estimator.Estimator`.
    input_function: Called with `is_training`, `use_random_crop`,
      `num_epochs` and `flags_obj` to build the train and eval inputs.
    dataset_name: Used for benchmark logging and, when `zeroshot_eval` is
      set, to look up the dataset config via `data_config.get_config`. A
      '-synthetic' suffix is appended before both uses when synthetic data
      is enabled.
    shape: Passed to `export_utils.export_pb` when `flags_obj.export_dir` is
      set.
    num_images: Mapping indexed with 'train' (and with 'validation' for
      zero-shot eval). Although it defaults to None, it is subscripted
      unconditionally, so callers must supply it.
    zeroshot_eval: If True, evaluation uses `recall_metric.recall_at_k` and
      the tracked metric is 'recall_at_1' instead of estimator 'accuracy'.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes. (Delegated to config_utils.get_session_config.)
  session_config = config_utils.get_session_config(flags_obj)
  run_config = config_utils.get_run_config(flags_obj, flags_core,
                                           session_config,
                                           num_images['train'])

  def gen_estimator(period=None):
    """Returns `(classifier, exp)` built from the current flag values.

    `period` is unused. `exp` is a hyperdash `Experiment` when hyperdash is
    installed and `--use_hyperdash` is set, otherwise None. In that case
    several hyper-parameters are re-read through `exp.param` so they are also
    recorded by the experiment.
    """
    resnet_size = int(flags_obj.resnet_size)
    data_format = flags_obj.data_format
    batch_size = flags_obj.batch_size
    resnet_version = int(flags_obj.resnet_version)
    loss_scale = flags_core.get_loss_scale(flags_obj)
    dtype_tf = flags_core.get_tf_dtype(flags_obj)
    num_epochs_per_decay = flags_obj.num_epochs_per_decay
    learning_rate_decay_factor = flags_obj.learning_rate_decay_factor
    end_learning_rate = flags_obj.end_learning_rate
    learning_rate_decay_type = flags_obj.learning_rate_decay_type
    weight_decay = flags_obj.weight_decay
    zero_gamma = flags_obj.zero_gamma
    lr_warmup_epochs = flags_obj.lr_warmup_epochs
    base_learning_rate = flags_obj.base_learning_rate
    use_resnet_d = flags_obj.use_resnet_d
    use_dropblock = flags_obj.use_dropblock
    dropblock_kp = [float(be) for be in flags_obj.dropblock_kp]
    label_smoothing = flags_obj.label_smoothing
    momentum = flags_obj.momentum
    bn_momentum = flags_obj.bn_momentum
    train_epochs = flags_obj.train_epochs
    piecewise_lr_boundary_epochs = [
        int(be) for be in flags_obj.piecewise_lr_boundary_epochs
    ]
    piecewise_lr_decay_rates = [
        float(dr) for dr in flags_obj.piecewise_lr_decay_rates
    ]
    use_ranking_loss = flags_obj.use_ranking_loss
    use_se_block = flags_obj.use_se_block
    use_sk_block = flags_obj.use_sk_block
    mixup_type = flags_obj.mixup_type
    # Local to gen_estimator; it does not affect the outer `dataset_name`.
    dataset_name = flags_obj.dataset_name
    kd_temp = flags_obj.kd_temp
    no_downsample = flags_obj.no_downsample
    anti_alias_filter_size = flags_obj.anti_alias_filter_size
    anti_alias_type = flags_obj.anti_alias_type
    cls_loss_type = flags_obj.cls_loss_type
    logit_type = flags_obj.logit_type
    embedding_size = flags_obj.embedding_size
    pool_type = flags_obj.pool_type
    arc_s = flags_obj.arc_s
    arc_m = flags_obj.arc_m
    bl_alpha = flags_obj.bl_alpha
    bl_beta = flags_obj.bl_beta
    exp = None

    if install_hyperdash and flags_obj.use_hyperdash:
      # Experiment name is the last path component of model_dir.
      exp = Experiment(flags_obj.model_dir.split("/")[-1])
      resnet_size = exp.param("resnet_size", int(flags_obj.resnet_size))
      batch_size = exp.param("batch_size", flags_obj.batch_size)
      exp.param("dtype", flags_obj.dtype)
      learning_rate_decay_type = exp.param(
          "learning_rate_decay_type", flags_obj.learning_rate_decay_type)
      weight_decay = exp.param("weight_decay", flags_obj.weight_decay)
      zero_gamma = exp.param("zero_gamma", flags_obj.zero_gamma)
      lr_warmup_epochs = exp.param("lr_warmup_epochs",
                                   flags_obj.lr_warmup_epochs)
      base_learning_rate = exp.param("base_learning_rate",
                                     flags_obj.base_learning_rate)
      use_dropblock = exp.param("use_dropblock", flags_obj.use_dropblock)
      dropblock_kp = exp.param(
          "dropblock_kp", [float(be) for be in flags_obj.dropblock_kp])
      piecewise_lr_boundary_epochs = exp.param(
          "piecewise_lr_boundary_epochs",
          [int(be) for be in flags_obj.piecewise_lr_boundary_epochs])
      piecewise_lr_decay_rates = exp.param(
          "piecewise_lr_decay_rates",
          [float(dr) for dr in flags_obj.piecewise_lr_decay_rates])
      mixup_type = exp.param("mixup_type", flags_obj.mixup_type)
      dataset_name = exp.param("dataset_name", flags_obj.dataset_name)
      exp.param("autoaugment_type", flags_obj.autoaugment_type)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': resnet_size,
            'data_format': data_format,
            'batch_size': batch_size,
            'resnet_version': resnet_version,
            'loss_scale': loss_scale,
            'dtype': dtype_tf,
            'num_epochs_per_decay': num_epochs_per_decay,
            'learning_rate_decay_factor': learning_rate_decay_factor,
            'end_learning_rate': end_learning_rate,
            'learning_rate_decay_type': learning_rate_decay_type,
            'weight_decay': weight_decay,
            'zero_gamma': zero_gamma,
            'lr_warmup_epochs': lr_warmup_epochs,
            'base_learning_rate': base_learning_rate,
            'use_resnet_d': use_resnet_d,
            'use_dropblock': use_dropblock,
            'dropblock_kp': dropblock_kp,
            'label_smoothing': label_smoothing,
            'momentum': momentum,
            'bn_momentum': bn_momentum,
            'embedding_size': embedding_size,
            'train_epochs': train_epochs,
            'piecewise_lr_boundary_epochs': piecewise_lr_boundary_epochs,
            'piecewise_lr_decay_rates': piecewise_lr_decay_rates,
            'with_drawing_bbox': flags_obj.with_drawing_bbox,
            'use_ranking_loss': use_ranking_loss,
            'use_se_block': use_se_block,
            'use_sk_block': use_sk_block,
            'mixup_type': mixup_type,
            'kd_temp': kd_temp,
            'no_downsample': no_downsample,
            'dataset_name': dataset_name,
            'anti_alias_filter_size': anti_alias_filter_size,
            'anti_alias_type': anti_alias_type,
            'cls_loss_type': cls_loss_type,
            'logit_type': logit_type,
            'arc_s': arc_s,
            'arc_m': arc_m,
            'pool_type': pool_type,
            'bl_alpha': bl_alpha,
            'bl_beta': bl_beta,
            # Closure over resnet_main's `total_train_steps`, which is
            # assigned later in the enclosing function but before
            # gen_estimator() is called.
            'train_steps': total_train_steps,
        })
    return classifier, exp

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                             model_dir=flags_obj.model_dir,
                                             batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(is_training=True,
                          use_random_crop=flags_obj.training_random_crop,
                          num_epochs=num_epochs,
                          flags_obj=flags_obj)

  def input_fn_eval():
    return input_function(is_training=False,
                          use_random_crop=False,
                          num_epochs=1,
                          flags_obj=flags_obj)

  # Keeps the best checkpoints according to the metric passed to save();
  # maximize=True means higher metric values are better.
  ckpt_keeper = checkpoint_utils.CheckpointKeeper(
      save_dir=flags_obj.model_dir,
      num_to_keep=flags_obj.num_best_ckpt_to_keep,
      keep_epoch=flags_obj.keep_ckpt_every_eval,
      maximize=True)

  if zeroshot_eval:
    # A standalone model object used only by recall_at_k during zero-shot
    # evaluation; `model` is undefined when zeroshot_eval is False.
    dataset = data_config.get_config(dataset_name)
    model = model_fns.Model(
        int(flags_obj.resnet_size),
        flags_obj.data_format,
        resnet_version=int(flags_obj.resnet_version),
        num_classes=dataset.num_classes,
        zero_gamma=flags_obj.zero_gamma,
        use_se_block=flags_obj.use_se_block,
        use_sk_block=flags_obj.use_sk_block,
        no_downsample=flags_obj.no_downsample,
        anti_alias_filter_size=flags_obj.anti_alias_filter_size,
        anti_alias_type=flags_obj.anti_alias_type,
        bn_momentum=flags_obj.bn_momentum,
        embedding_size=flags_obj.embedding_size,
        pool_type=flags_obj.pool_type,
        bl_alpha=flags_obj.bl_alpha,
        bl_beta=flags_obj.bl_beta,
        dtype=flags_core.get_tf_dtype(flags_obj),
        loss_type=flags_obj.cls_loss_type)

  def train_and_evaluate(hooks):
    """Runs one train/eval cycle and returns the eval results dict.

    Reads `cycle_index`, `num_train_epochs`, `n_loops`, `classifier` (and
    `model` for zero-shot eval) from the enclosing scope, so it is only
    valid when called from inside the cycle loop of resnet_main.
    """
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      # NOTE(review): `steps` trains for that many additional steps per
      # cycle, unlike `max_steps` (an absolute global-step limit); confirm
      # this is intended.
      classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                       hooks=hooks,
                       steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')
    if zeroshot_eval:
      tf.reset_default_graph()
      eval_results = recall_metric.recall_at_k(
          flags_obj, flags_core, input_fns.input_fn_ir_eval, model,
          num_images['validation'],
          eval_similarity=flags_obj.eval_similarity,
          return_embedding=True)
    else:
      eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                         steps=flags_obj.max_train_steps)

    return eval_results

  total_train_steps = flags_obj.train_epochs * int(
      num_images['train'] / flags_obj.batch_size)

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
    schedule, n_loops = [0], 1
  elif flags_obj.export_only:
    # No train/eval cycles; only the export at the end runs.
    schedule, n_loops = [], 0
  else:
    # epochs_between_evals per cycle; the last cycle gets the remainder.
    n_loops = math.ceil(flags_obj.train_epochs /
                        flags_obj.epochs_between_evals)
    schedule = [
        flags_obj.epochs_between_evals for _ in range(int(n_loops))
    ]
    schedule[-1] = flags_obj.train_epochs - sum(
        schedule[:-1])  # over counting.

  # NOTE(review): this re-scheduling is applied to every branch's schedule
  # (including eval_only/export_only); confirm that is intended.
  schedule = config_utils.get_epoch_schedule(flags_obj, schedule, num_images)
  tf.logging.info('epoch schedule:')
  tf.logging.info(schedule)

  classifier, exp = gen_estimator()
  if flags_obj.pretrained_model_checkpoint_path:
    warm_start_hook = WarmStartHook(
        flags_obj.pretrained_model_checkpoint_path)
    train_hooks.append(warm_start_hook)

  for cycle_index, num_train_epochs in enumerate(schedule):
    eval_results = train_and_evaluate(train_hooks)
    if zeroshot_eval:
      metric = eval_results['recall_at_1']
    else:
      metric = eval_results['accuracy']
    ckpt_keeper.save(metric, flags_obj.model_dir)
    if exp:
      # Logged under the name "accuracy" even when it is recall@1.
      exp.metric("accuracy", metric)
    benchmark_logger.log_evaluation_result(eval_results)
    if model_helpers.past_stop_threshold(flags_obj.stop_threshold, metric):
      break
    # Presumably stops once the evaluated global step reaches the planned
    # total number of training steps.
    if model_helpers.past_stop_threshold(total_train_steps,
                                         eval_results['global_step']):
      break

  if exp:
    exp.end()

  if flags_obj.export_dir is not None:
    export_utils.export_pb(flags_core, flags_obj, shape, classifier)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Training and evaluation are driven by an `orbit.Controller` around a
  `resnet_runnable.ResnetRunnable`. Evaluation is skipped when
  `--skip_eval` is set.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as produced by `build_stats`.
  """
  keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  if tf.config.list_physical_devices('GPU'):
    if flags_obj.tf_gpu_thread_mode:
      keras_utils.set_gpu_thread_mode_and_count(
          per_gpu_thread_count=flags_obj.per_gpu_thread_count,
          gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
          num_gpus=flags_obj.num_gpus,
          datasets_num_private_threads=flags_obj.datasets_num_private_threads)
    common.set_cudnn_batchnorm_mode()

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  # A host loop must not cross an epoch boundary.
  if flags_obj.steps_per_loop is None:
    steps_per_loop = per_epoch_steps
  elif flags_obj.steps_per_loop > per_epoch_steps:
    steps_per_loop = per_epoch_steps
    # `warn` is a deprecated alias; use `warning`.
    logging.warning('Setting steps_per_loop to %d to respect epoch boundary.',
                    steps_per_loop)
  else:
    steps_per_loop = flags_obj.steps_per_loop

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      train_epochs * per_epoch_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
  with distribute_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              per_epoch_steps)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
  # None disables step-interval checkpointing / summaries.
  checkpoint_interval = (
      steps_per_loop * 5 if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  resnet_controller = orbit.Controller(
      strategy=strategy,
      trainer=runnable,
      evaluator=runnable if not flags_obj.skip_eval else None,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      summary_dir=flags_obj.model_dir,
      eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

  time_callback.on_train_begin()
  if not flags_obj.skip_eval:
    resnet_controller.train_and_evaluate(
        train_steps=per_epoch_steps * train_epochs,
        eval_steps=eval_steps,
        eval_interval=eval_interval)
  else:
    resnet_controller.train(steps=per_epoch_steps * train_epochs)
  time_callback.on_train_end()

  stats = build_stats(runnable, time_callback)
  return stats
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Train, evaluate and optionally export a ResNet `tf.estimator.Estimator`.

  Args:
    flags_obj: Parsed flag values; see define_resnet_flags() for details.
    model_function: `model_fn` handed unchanged to the Estimator. It builds
      the model and the train/eval ops.
    input_function: Dataset builder. It is wrapped in small closures that bind
      the data directory, per-device batch size, epoch count and dtype taken
      from the flags.
    dataset_name: Dataset label used only for benchmark logging. A
      '-synthetic' suffix is added when synthetic data is requested.
    shape: Image shape used to build the serving input receiver. It is only
      consulted when `flags_obj.export_dir` is set.

  Returns:
    The metrics dict returned by the most recent evaluation.
  """
  model_helpers.apply_clean(flags.FLAGS)

  # Only touch GPU thread-pool flags/env vars when explicitly requested.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Soft placement is required for multi-GPU runs and harmless otherwise.
  sess_cfg = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  # A 24-hour timed checkpoint interval effectively leaves checkpointing to
  # the `epochs_between_evals` cycles.
  one_day_secs = 24 * 60 * 60
  estimator_config = tf.estimator.RunConfig(
      train_distribute=strategy,
      session_config=sess_cfg,
      save_checkpoints_secs=one_day_secs)

  # Warm-start every variable except the dense (classification) layer.
  warm_start = None
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')

  estimator_params = {
      'resnet_size': int(flags_obj.resnet_size),
      'data_format': flags_obj.data_format,
      'batch_size': flags_obj.batch_size,
      'resnet_version': int(flags_obj.resnet_version),
      'loss_scale': flags_core.get_loss_scale(flags_obj),
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'fine_tune': flags_obj.fine_tune,
  }
  resnet_estimator = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=estimator_config,
      warm_start_from=warm_start,
      params=estimator_params)

  logged_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  logged_name = (dataset_name + '-synthetic'
                 if flags_obj.use_synthetic_data else dataset_name)

  bench_logger = logger.get_benchmark_logger()
  bench_logger.log_run_info('resnet', logged_name, logged_params,
                            test_id=flags_obj.benchmark_test_id)

  hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def train_input(num_epochs):
    device_batch = distribution_utils.per_device_batch_size(
        flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=device_batch,
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches)

  def eval_input():
    device_batch = distribution_utils.per_device_batch_size(
        flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=device_batch,
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # --eval_only (or no training requested): one cycle with zero epochs.
    cycle_epochs, num_cycles = [0], 1
  else:
    # Every cycle but the last trains `epochs_between_evals` epochs; the last
    # trains whatever remains to reach `train_epochs`. For example,
    # train_epochs=25 and epochs_between_evals=10 give [10, 10, 5].
    num_cycles = math.ceil(flags_obj.train_epochs /
                           flags_obj.epochs_between_evals)
    cycle_epochs = [flags_obj.epochs_between_evals] * int(num_cycles)
    cycle_epochs[-1] = flags_obj.train_epochs - sum(cycle_epochs[:-1])

  for cycle, epochs_this_cycle in enumerate(cycle_epochs):
    tf.logging.info('Starting cycle: %d/%d', cycle, int(num_cycles))

    if epochs_this_cycle:
      resnet_estimator.train(
          input_fn=lambda: train_input(epochs_this_cycle),
          hooks=hooks,
          max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')
    # max_train_steps is mostly set for testing/profiling, often with
    # synthetic data that never runs out. Reusing it as the eval step count
    # lets eval terminate. Eval therefore runs max_train_steps steps every
    # cycle, regardless of the global step.
    eval_results = resnet_estimator.evaluate(
        input_fn=eval_input, steps=flags_obj.max_train_steps)
    bench_logger.log_evaluation_result(eval_results)

    reached_target = model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy'])
    if reached_target:
      break

  if flags_obj.export_dir is not None:
    # Export a SavedModel for the trained estimator.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    resnet_estimator.export_savedmodel(flags_obj.export_dir, receiver_fn,
                                       strip_default_attrs=True)

  return eval_results
def dtype():
  """Return the TensorFlow dtype chosen by the globally parsed flags."""
  parsed = flags.FLAGS
  return flags_core.get_tf_dtype(parsed)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If the flags select fp16 (`tf.float16`), which is not
      supported on this path.

  Returns:
    Dictionary of stats produced by `keras_common.build_stats`.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType. Comparing it with the string 'fp16' is
  # always False ('fp16' is not a TF dtype name), so compare with the DType.
  if dtype == tf.float16:
    raise ValueError(
        'dtype fp16 is not supported in Keras. Use the default '
        'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj))
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  # Repeated for train_epochs so that per-epoch validation in fit() does not
  # exhaust the dataset.
  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=flags_obj.batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    # An explicit step budget caps training to a single (partial) epoch.
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
def convinh_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for convinh Models.

  Builds a tf.estimator.Estimator from the parsed flags and evaluates the
  initial model. It then alternates training and evaluation cycles until
  either all cycles have run or the evaluation accuracy passes
  `flags_obj.stop_threshold`. When `flags_obj.export_dir` is set, a SavedModel
  is exported after every evaluation. After the first cycle, one extra
  SavedModel is also exported from the `model.ckpt-0` checkpoint.

  Args:
    flags_obj: An object containing parsed flags. See define_convinh_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This
      is used for logging purpose.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  # Computed once and reused below instead of re-querying per call.
  num_gpus = flags_core.get_num_gpus(flags_obj)
  dtype = flags_core.get_tf_dtype(flags_obj)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      num_gpus, flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      tf_random_seed=flags_obj.seed,
      train_distribute=distribution_strategy,
      session_config=session_config,
      keep_checkpoint_max=flags_obj.num_ckpt)

  def _as_ints(values):
    # List-valued flags are converted element-wise to int for the model.
    return [int(value) for value in values]

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'model_params': {
              'data_format': flags_obj.data_format,
              'filters': _as_ints(flags_obj.filters),
              'ratio_PV': flags_obj.ratio_PV,
              'ratio_SST': flags_obj.ratio_SST,
              'conv_kernel_size': _as_ints(flags_obj.conv_kernel_size),
              'conv_kernel_size_inh': _as_ints(flags_obj.conv_kernel_size_inh),
              'conv_strides': _as_ints(flags_obj.conv_strides),
              'pool_size': _as_ints(flags_obj.pool_size),
              'pool_strides': _as_ints(flags_obj.pool_strides),
              'num_ff_layers': flags_obj.num_ff_layers,
              'num_rnn_layers': flags_obj.num_rnn_layers,
              'connection': flags_obj.connection,
              'n_time': flags_obj.n_time,
              'cell_fn': flags_obj.cell_fn,
              'act_fn': flags_obj.act_fn,
              'pvsst_circuit': flags_obj.pvsst_circuit,
              'gating': flags_obj.gating,
              'normalize': flags_obj.normalize,
              'num_classes': flags_obj.num_classes,
          },
          'batch_size': flags_obj.batch_size,
          'weight_decay': flags_obj.weight_decay,
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': dtype,
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': dtype,
      'convinh_size': flags_obj.convinh_size,  # deprecated
      'convinh_version': flags_obj.convinh_version,  # deprecated
      'synthetic_data': flags_obj.use_synthetic_data,  # deprecated
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('convinh', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    """Return a zero-argument training input_fn reading `num_epochs` epochs."""
    def _train_input_fn():
      return input_function(
          is_training=True, data_dir=flags_obj.data_dir,
          batch_size=distribution_utils.per_device_batch_size(
              flags_obj.batch_size, num_gpus),
          num_epochs=num_epochs,
          num_gpus=num_gpus)
    return _train_input_fn

  def input_fn_eval():
    """Evaluation input_fn: a single pass over the evaluation data."""
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus),
        num_epochs=1)

  tf.logging.info('Evaluate the intial model.')
  eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                     steps=flags_obj.max_train_steps)
  benchmark_logger.log_evaluation_result(eval_results)

  # The serving input receiver does not change between cycles; build it once.
  input_receiver_fn = None
  if flags_obj.export_dir is not None:
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=1)

  # training
  # The first cycle trains a single epoch; every later cycle trains
  # epochs_between_evals epochs.
  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals) + 1
  for cycle_index in range(total_training_cycle):
    cur_train_epochs = flags_obj.epochs_between_evals if cycle_index else 1
    tf.logging.info('Starting a training cycle: %d/%d, with %d epochs',
                    cycle_index, total_training_cycle, cur_train_epochs)

    classifier.train(input_fn=input_fn_train(cur_train_epochs),
                     hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)
    benchmark_logger.log_evaluation_result(eval_results)

    # Export BEFORE checking the stop threshold: the original `break` ran
    # first, so the model that reached the threshold was never exported.
    if input_receiver_fn is not None:
      if cycle_index == 0:
        # Also export the checkpoint saved at global step 0 (presumably the
        # untrained, initial model).
        classifier.export_savedmodel(
            flags_obj.export_dir, input_receiver_fn,
            checkpoint_path='{}/model.ckpt-0'.format(flags_obj.model_dir))
      # Exports a saved model from the latest checkpoint.
      classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break