def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  crop_shape = (FLAGS.cifar100_crop_size, FLAGS.cifar100_crop_size, 3)
  cifar_train, cifar_test = cifar100_dataset.get_centralized_cifar100(
      train_batch_size=FLAGS.batch_size, crop_shape=crop_shape)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  model = resnet_models.create_resnet18(
      input_shape=crop_shape, num_classes=NUM_CLASSES)
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  hparams_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in hparam_flags])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=cifar_train,
      validation_dataset=cifar_test,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      num_epochs=FLAGS.num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=FLAGS.decay_epochs,
      lr_decay=FLAGS.lr_decay)

def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_dataset, eval_dataset = emnist_dataset.get_centralized_emnist_datasets(
      batch_size=FLAGS.batch_size, only_digits=False)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  if FLAGS.model == 'cnn':
    model = emnist_models.create_conv_dropout_model(only_digits=False)
  elif FLAGS.model == '2nn':
    model = emnist_models.create_two_hidden_layer_model(only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  hparams_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in hparam_flags])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=eval_dataset,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      num_epochs=FLAGS.num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=FLAGS.decay_epochs,
      lr_decay=FLAGS.lr_decay)

def test_metric_writing_without_validation(self):
  keras_model = compiled_keras_model()
  dataset = create_dataset()
  exp_name = 'write_metrics'
  temp_filepath = self.get_temp_dir()
  root_output_dir = temp_filepath

  centralized_training_loop.run(
      keras_model=keras_model,
      train_dataset=dataset,
      experiment_name=exp_name,
      root_output_dir=root_output_dir,
      num_epochs=3)

  self.assertTrue(tf.io.gfile.exists(root_output_dir))

  log_dir = os.path.join(root_output_dir, 'logdir', exp_name)
  train_log_dir = os.path.join(log_dir, 'train')
  validation_log_dir = os.path.join(log_dir, 'validation')
  self.assertTrue(tf.io.gfile.exists(log_dir))
  self.assertTrue(tf.io.gfile.exists(train_log_dir))
  self.assertFalse(tf.io.gfile.exists(validation_log_dir))

  results_dir = os.path.join(root_output_dir, 'results', exp_name)
  self.assertTrue(tf.io.gfile.exists(results_dir))

  metrics_file = os.path.join(results_dir, 'metric_results.csv')
  self.assertTrue(tf.io.gfile.exists(metrics_file))

  hparams_file = os.path.join(results_dir, 'hparams.csv')
  self.assertFalse(tf.io.gfile.exists(hparams_file))

  metrics_csv = pd.read_csv(metrics_file)
  self.assertEqual(metrics_csv.shape, (3, 3))
  self.assertCountEqual(metrics_csv.columns,
                        ['Unnamed: 0', 'loss', 'mean_squared_error'])

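# For reference, the output layout this test checks for (derived from the
# assertions above) looks roughly like:
#
#   <root_output_dir>/
#     logdir/<experiment_name>/train/            # TensorBoard training logs
#     results/<experiment_name>/metric_results.csv
#
# With no validation dataset, no `logdir/<experiment_name>/validation/`
# directory is expected, and with no `hparams_dict`, no `hparams.csv` file is
# expected either.
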
def test_hparam_writing(self):
  keras_model = compiled_keras_model()
  dataset = create_dataset()
  exp_name = 'write_hparams'
  temp_filepath = self.get_temp_dir()
  root_output_dir = temp_filepath
  hparams_dict = {
      'param1': 0,
      'param2': 5.02,
      'param3': 'sample',
      'param4': True
  }

  centralized_training_loop.run(
      keras_model=keras_model,
      train_dataset=dataset,
      experiment_name=exp_name,
      root_output_dir=root_output_dir,
      num_epochs=1,
      hparams_dict=hparams_dict)

  self.assertTrue(tf.io.gfile.exists(root_output_dir))

  results_dir = os.path.join(root_output_dir, 'results', exp_name)
  self.assertTrue(tf.io.gfile.exists(results_dir))

  hparams_file = os.path.join(results_dir, 'hparams.csv')
  self.assertTrue(tf.io.gfile.exists(hparams_file))

  hparams_csv = pd.read_csv(hparams_file, index_col=0)
  expected_csv = pd.DataFrame(hparams_dict, index=[0])
  pd.testing.assert_frame_equal(hparams_csv, expected_csv)

def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_dataset, eval_dataset = (
      emnist_ae_dataset.get_centralized_emnist_datasets(
          batch_size=FLAGS.batch_size, only_digits=False))

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  model = emnist_ae_models.create_autoencoder_model()
  model.compile(
      loss=tf.keras.losses.MeanSquaredError(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.MeanSquaredError()])

  hparams_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in hparam_flags])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=eval_dataset,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      num_epochs=FLAGS.num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=FLAGS.decay_epochs,
      lr_decay=FLAGS.lr_decay)

def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    sequence_length: Optional[int] = 80,
                    max_batches: Optional[int] = None):
  """Trains a two-layer RNN on Shakespeare next-character prediction.

  Args:
    optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
    experiment_name: The name of the experiment. Part of the output directory.
    root_output_dir: The top-level output directory for experiment runs. The
      `experiment_name` argument will be appended, and the directory will
      contain tensorboard logs, metrics written as CSVs, and a CSV of
      hyperparameter choices (if `hparams_dict` is used).
    num_epochs: The number of training epochs.
    batch_size: The batch size, used for train, validation, and test.
    decay_epochs: The number of epochs of training before decaying the learning
      rate. If None, no decay occurs.
    lr_decay: The amount to decay the learning rate by after `decay_epochs`
      training epochs have occurred.
    hparams_dict: A mapping with string keys representing the hyperparameters
      and their values. If not None, this is written to CSV.
    sequence_length: The sequence length used for Shakespeare preprocessing.
    max_batches: If set to a positive integer, datasets are capped to at most
      that many batches. If set to None or a nonpositive integer, the full
      datasets are used.
  """
  train_dataset, eval_dataset = shakespeare_dataset.get_centralized_datasets(
      train_batch_size=batch_size,
      max_train_batches=max_batches,
      max_test_batches=max_batches,
      sequence_length=sequence_length)

  pad_token, _, _, _ = shakespeare_dataset.get_special_tokens()

  model = shakespeare_models.create_recurrent_model(
      vocab_size=VOCAB_SIZE, sequence_length=sequence_length)
  model.compile(
      optimizer=optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[
          keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
      ])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=eval_dataset,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      num_epochs=num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=decay_epochs,
      lr_decay=lr_decay)

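# Illustrative usage sketch of the function above, e.g. from a driver script
# where `run_centralized` is in scope. The optimizer settings, output paths,
# and hyperparameter values below are assumptions for illustration, not
# defaults taken from this file.
#
# run_centralized(
#     optimizer=tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
#     experiment_name='shakespeare_centralized_sgd',
#     root_output_dir='/tmp/centralized_opt',
#     num_epochs=10,
#     batch_size=32,
#     sequence_length=80)
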
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    emnist_model: Optional[str] = 'cnn'):
  """Trains a model on EMNIST character recognition using a given optimizer.

  Args:
    optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
    experiment_name: The name of the experiment. Part of the output directory.
    root_output_dir: The top-level output directory for experiment runs. The
      `experiment_name` argument will be appended, and the directory will
      contain tensorboard logs, metrics written as CSVs, and a CSV of
      hyperparameter choices (if `hparams_dict` is used).
    num_epochs: The number of training epochs.
    batch_size: The batch size, used for train, validation, and test.
    decay_epochs: The number of epochs of training before decaying the learning
      rate. If None, no decay occurs.
    lr_decay: The amount to decay the learning rate by after `decay_epochs`
      training epochs have occurred.
    hparams_dict: A mapping with string keys representing the hyperparameters
      and their values. If not None, this is written to CSV.
    emnist_model: A string specifying the model used for character recognition.
      Can be one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).
  """
  train_dataset, eval_dataset = emnist_dataset.get_centralized_emnist_datasets(
      batch_size=batch_size, only_digits=False)

  if emnist_model == 'cnn':
    model = emnist_models.create_conv_dropout_model(only_digits=False)
  elif emnist_model == '2nn':
    model = emnist_models.create_two_hidden_layer_model(only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(emnist_model))

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=eval_dataset,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      num_epochs=num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=decay_epochs,
      lr_decay=lr_decay)

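# Illustrative usage sketch: training the densely connected '2nn' EMNIST model
# instead of the default CNN. All values below are example assumptions, not
# defaults from this file.
#
# run_centralized(
#     optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
#     experiment_name='emnist_2nn_centralized',
#     root_output_dir='/tmp/centralized_opt',
#     num_epochs=20,
#     batch_size=128,
#     emnist_model='2nn')
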
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    crop_size: Optional[int] = 24,
                    max_batches: Optional[int] = None):
  """Trains a ResNet-18 on CIFAR-100 using a given optimizer.

  Args:
    optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
    experiment_name: The name of the experiment. Part of the output directory.
    root_output_dir: The top-level output directory for experiment runs. The
      `experiment_name` argument will be appended, and the directory will
      contain tensorboard logs, metrics written as CSVs, and a CSV of
      hyperparameter choices (if `hparams_dict` is used).
    num_epochs: The number of training epochs.
    batch_size: The batch size, used for train, validation, and test.
    decay_epochs: The number of epochs of training before decaying the learning
      rate. If None, no decay occurs.
    lr_decay: The amount to decay the learning rate by after `decay_epochs`
      training epochs have occurred.
    hparams_dict: A mapping with string keys representing the hyperparameters
      and their values. If not None, this is written to CSV.
    crop_size: The crop size used for CIFAR-100 preprocessing.
    max_batches: If set to a positive integer, datasets are capped to at most
      that many batches. If set to None or a nonpositive integer, the full
      datasets are used.
  """
  crop_shape = (crop_size, crop_size, NUM_CHANNELS)

  cifar_train, cifar_test = cifar100_dataset.get_centralized_datasets(
      train_batch_size=batch_size,
      max_train_batches=max_batches,
      max_test_batches=max_batches,
      crop_shape=crop_shape)

  model = resnet_models.create_resnet18(
      input_shape=crop_shape, num_classes=NUM_CLASSES)
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=cifar_train,
      validation_dataset=cifar_test,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      num_epochs=num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=decay_epochs,
      lr_decay=lr_decay)

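# Illustrative usage sketch: CIFAR-100 training with a stepped learning-rate
# decay (the rate is multiplied by `lr_decay` after `decay_epochs` epochs).
# The concrete numbers below are example assumptions, not defaults from this
# file.
#
# run_centralized(
#     optimizer=tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
#     experiment_name='cifar100_resnet18_centralized',
#     root_output_dir='/tmp/centralized_opt',
#     num_epochs=100,
#     batch_size=64,
#     decay_epochs=50,
#     lr_decay=0.1,
#     crop_size=24)
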
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    max_batches: Optional[int] = None):
  """Trains a bottleneck autoencoder on EMNIST using a given optimizer.

  Args:
    optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
    experiment_name: The name of the experiment. Part of the output directory.
    root_output_dir: The top-level output directory for experiment runs. The
      `experiment_name` argument will be appended, and the directory will
      contain tensorboard logs, metrics written as CSVs, and a CSV of
      hyperparameter choices (if `hparams_dict` is used).
    num_epochs: The number of training epochs.
    batch_size: The batch size, used for train, validation, and test.
    decay_epochs: The number of epochs of training before decaying the learning
      rate. If None, no decay occurs.
    lr_decay: The amount to decay the learning rate by after `decay_epochs`
      training epochs have occurred.
    hparams_dict: A mapping with string keys representing the hyperparameters
      and their values. If not None, this is written to CSV.
    max_batches: If set to a positive integer, datasets are capped to at most
      that many batches. If set to None or a nonpositive integer, the full
      datasets are used.
  """
  train_dataset, eval_dataset = emnist_ae_dataset.get_centralized_datasets(
      train_batch_size=batch_size,
      max_train_batches=max_batches,
      max_test_batches=max_batches,
      only_digits=False)

  model = emnist_ae_models.create_autoencoder_model()
  model.compile(
      loss=tf.keras.losses.MeanSquaredError(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.MeanSquaredError()])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=eval_dataset,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      num_epochs=num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=decay_epochs,
      lr_decay=lr_decay)

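# Illustrative usage sketch: a quick smoke-test run of the autoencoder
# baseline, capping both datasets at a small number of batches via
# `max_batches`. Values are example assumptions.
#
# run_centralized(
#     optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
#     experiment_name='emnist_ae_smoke_test',
#     root_output_dir='/tmp/centralized_opt',
#     num_epochs=1,
#     batch_size=128,
#     max_batches=10)
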
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_client_data, test_client_data = (
      tff.simulation.datasets.shakespeare.load_data())

  def preprocess(ds):
    return shakespeare_dataset.convert_snippets_to_character_sequence_examples(
        dataset=ds,
        batch_size=FLAGS.batch_size,
        epochs=1,
        shuffle_buffer_size=0,
        sequence_length=FLAGS.shakespeare_sequence_length)

  train_dataset = train_client_data.create_tf_dataset_from_all_clients()
  if FLAGS.shuffle_train_data:
    train_dataset = train_dataset.shuffle(buffer_size=10000)
  train_dataset = preprocess(train_dataset)

  eval_dataset = preprocess(
      test_client_data.create_tf_dataset_from_all_clients())

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  pad_token, _, _, _ = shakespeare_dataset.get_special_tokens()

  model = shakespeare_models.create_recurrent_model(
      vocab_size=VOCAB_SIZE,
      sequence_length=FLAGS.shakespeare_sequence_length)
  model.compile(
      optimizer=optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[
          keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
      ])

  hparams_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in hparam_flags])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=eval_dataset,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      num_epochs=FLAGS.num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=FLAGS.decay_epochs,
      lr_decay=FLAGS.lr_decay)

def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_dataset, validation_dataset, test_dataset = (
      stackoverflow_lr_dataset.get_centralized_stackoverflow_datasets(
          batch_size=FLAGS.batch_size,
          vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
          vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
          shuffle_buffer_size=FLAGS.shuffle_buffer_size,
          num_validation_examples=FLAGS.so_lr_num_validation_examples))

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  model = stackoverflow_lr_models.create_logistic_model(
      vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
      vocab_tags_size=FLAGS.so_lr_vocab_tags_size)
  model.compile(
      loss=tf.keras.losses.BinaryCrossentropy(
          from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
      optimizer=optimizer,
      metrics=[
          tf.keras.metrics.Precision(),
          tf.keras.metrics.Recall(top_k=5)
      ])

  hparams_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in hparam_flags])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=validation_dataset,
      test_dataset=test_dataset,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      num_epochs=FLAGS.num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=FLAGS.decay_epochs,
      lr_decay=FLAGS.lr_decay)

def test_lr_callback(self):
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  keras_model = compiled_keras_model(optimizer=optimizer)
  dataset = create_dataset()
  history = centralized_training_loop.run(
      keras_model=keras_model,
      train_dataset=dataset,
      experiment_name='test_experiment',
      root_output_dir=self.get_temp_dir(),
      num_epochs=10,
      decay_epochs=8,
      lr_decay=0.5,
      validation_dataset=dataset)

  self.assertCountEqual(history.history.keys(), [
      'loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error', 'lr'
  ])
  self.assertAllClose(history.history['lr'], [0.1] * 7 + [0.05] * 3)

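# The asserted schedule can be reproduced by a rule that multiplies the
# learning rate by `lr_decay` every `decay_epochs` epochs. The sketch below is
# an illustration of such a rule (not necessarily this training loop's exact
# callback); with the test's settings (num_epochs=10, decay_epochs=8,
# lr_decay=0.5, initial rate 0.1) it yields [0.1] * 7 + [0.05] * 3.
def _example_decay_schedule(initial_lr=0.1, num_epochs=10, decay_epochs=8,
                            lr_decay=0.5):
  lrs = []
  lr = initial_lr
  for epoch in range(num_epochs):
    # Decay every `decay_epochs` epochs; with 0-indexed epochs, the first
    # decayed value is used for epoch index `decay_epochs - 1`.
    if (epoch + 1) % decay_epochs == 0:
      lr *= lr_decay
    lrs.append(lr)
  return lrs  # For the defaults above: [0.1] * 7 + [0.05] * 3.
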
def test_training_reduces_loss(self):
  keras_model = compiled_keras_model()
  dataset = create_dataset()
  history = centralized_training_loop.run(
      keras_model=keras_model,
      train_dataset=dataset,
      experiment_name='test_experiment',
      root_output_dir=self.get_temp_dir(),
      num_epochs=5,
      validation_dataset=dataset)

  self.assertCountEqual(
      history.history.keys(),
      ['loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error'])

  self.assertMetricDecreases(history.history['loss'], expected_len=5)
  self.assertMetricDecreases(history.history['val_loss'], expected_len=5)
  self.assertMetricDecreases(
      history.history['mean_squared_error'], expected_len=5)
  self.assertMetricDecreases(
      history.history['val_mean_squared_error'], expected_len=5)

def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    vocab_size: Optional[int] = 10000,
                    num_oov_buckets: Optional[int] = 1,
                    sequence_length: Optional[int] = 20,
                    num_validation_examples: Optional[int] = 10000,
                    embedding_size: Optional[int] = 96,
                    latent_size: Optional[int] = 670,
                    num_layers: Optional[int] = 1,
                    shared_embedding: Optional[bool] = False,
                    max_batches: Optional[int] = None):
  """Trains an RNN on the Stack Overflow next word prediction task.

  Args:
    optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
    experiment_name: The name of the experiment. Part of the output directory.
    root_output_dir: The top-level output directory for experiment runs. The
      `experiment_name` argument will be appended, and the directory will
      contain tensorboard logs, metrics written as CSVs, and a CSV of
      hyperparameter choices (if `hparams_dict` is used).
    num_epochs: The number of training epochs.
    batch_size: The batch size, used for train, validation, and test.
    decay_epochs: The number of epochs of training before decaying the learning
      rate. If None, no decay occurs.
    lr_decay: The amount to decay the learning rate by after `decay_epochs`
      training epochs have occurred.
    hparams_dict: A mapping with string keys representing the hyperparameters
      and their values. If not None, this is written to CSV.
    vocab_size: Integer dictating the number of most frequent words to use in
      the vocabulary.
    num_oov_buckets: The number of out-of-vocabulary buckets to use.
    sequence_length: The maximum number of words to take for each sequence.
    num_validation_examples: The number of test examples to use for validation.
    embedding_size: The dimension of the word embedding layer.
    latent_size: The dimension of the latent units in the recurrent layers.
    num_layers: The number of stacked recurrent layers to use.
    shared_embedding: Boolean indicating whether to tie input and output
      embeddings.
    max_batches: If set to a positive integer, datasets are capped to at most
      that many batches. If set to None or a nonpositive integer, the full
      datasets are used.
  """
  train_dataset, validation_dataset, test_dataset = (
      stackoverflow_dataset.get_centralized_datasets(
          vocab_size=vocab_size,
          max_seq_len=sequence_length,
          train_batch_size=batch_size,
          max_train_batches=max_batches,
          max_validation_batches=max_batches,
          max_test_batches=max_batches,
          num_validation_examples=num_validation_examples,
          num_oov_buckets=num_oov_buckets))

  model = stackoverflow_models.create_recurrent_model(
      vocab_size=vocab_size,
      num_oov_buckets=num_oov_buckets,
      name='stackoverflow-lstm',
      embedding_size=embedding_size,
      latent_size=latent_size,
      num_layers=num_layers,
      shared_embedding=shared_embedding)

  special_tokens = stackoverflow_dataset.get_special_tokens(
      vocab_size=vocab_size, num_oov_buckets=num_oov_buckets)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=optimizer,
      metrics=[
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_with_oov', masked_tokens=[pad_token]),
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_no_oov_or_eos',
              masked_tokens=[pad_token, eos_token] + oov_tokens),
      ])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=validation_dataset,
      test_dataset=test_dataset,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      num_epochs=num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=decay_epochs,
      lr_decay=lr_decay)

def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    vocab_tokens_size: Optional[int] = 10000,
                    vocab_tags_size: Optional[int] = 500,
                    num_validation_examples: Optional[int] = 10000,
                    max_batches: Optional[int] = None):
  """Trains a logistic regression model on the Stack Overflow tag prediction task.

  Args:
    optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
    experiment_name: The name of the experiment. Part of the output directory.
    root_output_dir: The top-level output directory for experiment runs. The
      `experiment_name` argument will be appended, and the directory will
      contain tensorboard logs, metrics written as CSVs, and a CSV of
      hyperparameter choices (if `hparams_dict` is used).
    num_epochs: The number of training epochs.
    batch_size: The batch size, used for train, validation, and test.
    decay_epochs: The number of epochs of training before decaying the learning
      rate. If None, no decay occurs.
    lr_decay: The amount to decay the learning rate by after `decay_epochs`
      training epochs have occurred.
    hparams_dict: A mapping with string keys representing the hyperparameters
      and their values. If not None, this is written to CSV.
    vocab_tokens_size: Integer dictating the number of most frequent words to
      use in the vocabulary.
    vocab_tags_size: Integer dictating the number of most frequent tags to use
      in the label creation.
    num_validation_examples: The number of test examples to use for validation.
    max_batches: If set to a positive integer, datasets are capped to at most
      that many batches. If set to None or a nonpositive integer, the full
      datasets are used.
  """
  train_dataset, validation_dataset, test_dataset = (
      stackoverflow_lr_dataset.get_centralized_datasets(
          train_batch_size=batch_size,
          max_train_batches=max_batches,
          max_validation_batches=max_batches,
          max_test_batches=max_batches,
          vocab_tokens_size=vocab_tokens_size,
          vocab_tags_size=vocab_tags_size,
          num_validation_examples=num_validation_examples))

  model = stackoverflow_lr_models.create_logistic_model(
      vocab_tokens_size=vocab_tokens_size, vocab_tags_size=vocab_tags_size)
  model.compile(
      loss=tf.keras.losses.BinaryCrossentropy(
          from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
      optimizer=optimizer,
      metrics=[
          tf.keras.metrics.Precision(),
          tf.keras.metrics.Recall(top_k=5)
      ])

  centralized_training_loop.run(
      keras_model=model,
      train_dataset=train_dataset,
      validation_dataset=validation_dataset,
      test_dataset=test_dataset,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      num_epochs=num_epochs,
      hparams_dict=hparams_dict,
      decay_epochs=decay_epochs,
      lr_decay=lr_decay)

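# Illustrative usage sketch: tag prediction (logistic regression) training.
# All values below are example assumptions, not defaults from this file.
#
# run_centralized(
#     optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.01),
#     experiment_name='stackoverflow_lr_centralized',
#     root_output_dir='/tmp/centralized_opt',
#     num_epochs=10,
#     batch_size=100,
#     vocab_tokens_size=10000,
#     vocab_tags_size=500)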