def train_and_eval(
        params: base_configs.ExperimentConfig,
        strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
    """Runs the train and eval path using compile/fit.

    Args:
      params: Experiment configuration. This function reads its `runtime`,
        `model`, `train`, `evaluation` and `model_dir` fields.
      strategy_override: Distribution strategy to use. When falsy, a strategy
        is created from the `params.runtime` settings instead.

    Returns:
      The stats produced by `common.build_stats` from the fit history, the
      final evaluation output (`None` when evaluation is skipped) and the
      callbacks.

    Raises:
      ValueError: If evaluation is enabled, `params.evaluation.steps` is not
        set, and there is no validation dataset builder to take the step
        count from.
    """
    logging.info('Running train and eval.')

    distribute_utils.configure_cluster(params.runtime.worker_hosts,
                                       params.runtime.task_index)

    # Note: for TPUs, strategy and scope should be created before the dataset
    strategy = strategy_override or distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    strategy_scope = distribute_utils.get_strategy_scope(strategy)

    logging.info('Detected %d devices.',
                 strategy.num_replicas_in_sync if strategy else 1)

    # Labels are one-hot encoded only when label smoothing is active. Note
    # that `one_hot` is truthy/falsy rather than a strict bool (it is `None`
    # or `0` when smoothing is unset or zero).
    label_smoothing = params.model.loss.label_smoothing
    one_hot = label_smoothing and label_smoothing > 0

    builders = _get_dataset_builders(params, strategy, one_hot)
    # A builder may be missing; its dataset slot is then None.
    datasets = [
        builder.build(strategy) if builder else None for builder in builders
    ]

    # Unpack datasets and builders based on train/val/test splits
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    skip_eval = params.evaluation.skip_eval
    train_epochs = params.train.epochs
    train_steps = params.train.steps or train_builder.num_steps

    # Explicitly configured steps win; otherwise fall back to the builder's
    # step count. The builder can be None (see the `datasets` comprehension
    # above), so only dereference it when present, and fail early with a
    # clear message instead of an AttributeError if evaluation needs it.
    validation_steps = params.evaluation.steps
    if not validation_steps:
        if validation_builder is not None:
            validation_steps = validation_builder.num_steps
        elif not skip_eval:
            raise ValueError(
                'Evaluation is enabled but there is no validation dataset '
                'builder and `evaluation.steps` is not set. Provide a '
                'validation dataset, set `evaluation.steps`, or set '
                '`evaluation.skip_eval`.')

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    # Model, optimizer, metrics and callbacks are all created under the
    # distribution strategy scope.
    with strategy_scope:
        model_params = params.model.model_params.as_dict()
        model = get_models()[params.model.name](**model_params)
        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_epochs=train_epochs,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict(),
            model=model)
        optimizer = performance.configure_optimizer(
            optimizer,
            use_float16=train_builder.dtype == 'float16',
            loss_scale=get_loss_scale(params))

        metrics_map = _get_metrics(one_hot)
        metrics = [metrics_map[metric] for metric in params.train.metrics]
        # Execute a whole epoch per call when `set_epoch_loop` is enabled,
        # otherwise a single step per call.
        steps_per_loop = train_steps if params.train.set_epoch_loop else 1

        # The loss must match the label encoding chosen above.
        if one_hot:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=label_smoothing)
        else:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(
            optimizer=optimizer,
            loss=loss_obj,
            metrics=metrics,
            steps_per_execution=steps_per_loop)

        initial_epoch = 0
        if params.train.resume_checkpoint:
            initial_epoch = resume_from_checkpoint(
                model=model,
                model_dir=params.model_dir,
                train_steps=train_steps)

        callbacks = custom_callbacks.get_callbacks(
            model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
            include_tensorboard=params.train.callbacks.enable_tensorboard,
            time_history=params.train.callbacks.enable_time_history,
            track_lr=params.train.tensorboard.track_lr,
            write_model_weights=params.train.tensorboard.write_model_weights,
            initial_step=initial_epoch * train_steps,
            batch_size=train_builder.global_batch_size,
            log_steps=params.train.time_history.log_steps,
            model_dir=params.model_dir,
            backup_and_restore=params.train.callbacks.enable_backup_and_restore)

    serialize_config(params=params, model_dir=params.model_dir)

    # Validation arguments are passed to `fit` only when evaluation is on.
    if skip_eval:
        validation_kwargs = {}
    else:
        validation_kwargs = {
            'validation_data': validation_dataset,
            'validation_steps': validation_steps,
            'validation_freq': params.evaluation.epochs_between_evals,
        }

    history = model.fit(
        train_dataset,
        epochs=train_epochs,
        steps_per_epoch=train_steps,
        initial_epoch=initial_epoch,
        callbacks=callbacks,
        verbose=2,
        **validation_kwargs)

    validation_output = None
    if not skip_eval:
        validation_output = model.evaluate(
            validation_dataset, steps=validation_steps, verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    stats = common.build_stats(history, validation_output, callbacks)
    return stats
def run(flags_obj, datasets_override=None, strategy_override=None):
    """Train and evaluate the MNIST model with the native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.
      datasets_override: A pair of `tf.data.Dataset` objects (train, test)
        used in place of the TFDS MNIST splits when provided.
      strategy_override: A `tf.distribute.Strategy` object to use for the
        model instead of one built from the flags.

    Returns:
      Dictionary of training and eval stats, as built by
      `common.build_stats`.
    """
    # Start TF profiler server.
    tf.profiler.experimental.server.start(flags_obj.profiler_port)

    # Use the caller's strategy when given, otherwise build one from flags.
    strategy = strategy_override
    if not strategy:
        strategy = distribute_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=flags_obj.num_gpus,
            tpu_address=flags_obj.tpu)
    strategy_scope = distribute_utils.get_strategy_scope(strategy)

    mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
    if flags_obj.download:
        mnist.download_and_prepare()

    # Caller-supplied datasets take precedence over the TFDS splits.
    if datasets_override:
        train_split, test_split = datasets_override
    else:
        train_split, test_split = mnist.as_dataset(
            split=['train', 'test'],
            decoders={'image': decode_image()},  # pylint: disable=no-value-for-parameter
            as_supervised=True)

    # Input pipelines: cache -> repeat -> (shuffle for train) -> batch.
    train_input_dataset = train_split.cache().repeat()
    train_input_dataset = train_input_dataset.shuffle(buffer_size=50000)
    train_input_dataset = train_input_dataset.batch(flags_obj.batch_size)

    eval_input_dataset = test_split.cache().repeat()
    eval_input_dataset = eval_input_dataset.batch(flags_obj.batch_size)

    # Optimizer and model are created and compiled under the strategy scope.
    with strategy_scope:
        optimizer = tf.keras.optimizers.SGD(
            learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                0.05, decay_steps=100000, decay_rate=0.96))

        model = build_model()
        model.compile(
            optimizer=optimizer,
            loss='sparse_categorical_crossentropy',
            metrics=['sparse_categorical_accuracy'])

    # Step counts come from the TFDS split sizes, using floor division so
    # that only full batches are counted.
    train_example_count = mnist.info.splits['train'].num_examples
    steps_per_epoch = train_example_count // flags_obj.batch_size

    checkpoint_pattern = os.path.join(flags_obj.model_dir,
                                      'model.ckpt-{epoch:04d}')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_pattern, save_weights_only=True)
    tensorboard_cb = tf.keras.callbacks.TensorBoard(
        log_dir=flags_obj.model_dir)
    callbacks = [checkpoint_cb, tensorboard_cb]

    eval_example_count = mnist.info.splits['test'].num_examples
    eval_steps = eval_example_count // flags_obj.batch_size

    fit_kwargs = {
        'epochs': flags_obj.train_epochs,
        'steps_per_epoch': steps_per_epoch,
        'callbacks': callbacks,
        'validation_steps': eval_steps,
        'validation_data': eval_input_dataset,
        'validation_freq': flags_obj.epochs_between_evals,
    }
    history = model.fit(train_input_dataset, **fit_kwargs)

    # Export the trained model without optimizer state.
    model.save(os.path.join(flags_obj.model_dir, 'saved_model'),
               include_optimizer=False)

    eval_output = model.evaluate(
        eval_input_dataset, steps=eval_steps, verbose=2)

    return common.build_stats(history, eval_output, callbacks)