def test_hparam_writing(self):
    """Verifies that a supplied `hparams_dict` is written as a one-row CSV.

    Runs a single training epoch with hyperparameters attached, then checks
    that `<root>/results/<experiment>/hparams.csv` exists and round-trips to
    the same one-row DataFrame that was passed in.
    """
    model = compiled_keras_model()
    train_data = create_dataset()
    experiment = 'write_hparams'
    output_root = self.get_temp_dir()
    expected_hparams = {
        'param1': 0,
        'param2': 5.02,
        'param3': 'sample',
        'param4': True
    }

    centralized_training_loop.run(
        keras_model=model,
        train_dataset=train_data,
        experiment_name=experiment,
        root_output_dir=output_root,
        num_epochs=1,
        hparams_dict=expected_hparams)

    results_dir = os.path.join(output_root, 'results', experiment)
    hparams_path = os.path.join(results_dir, 'hparams.csv')
    # Check the directory chain from the root down to the CSV itself.
    for path in (output_root, results_dir, hparams_path):
        self.assertTrue(tf.io.gfile.exists(path))

    # The CSV's first column is the row index, so read it back as the index.
    written = pd.read_csv(hparams_path, index_col=0)
    pd.testing.assert_frame_equal(
        written, pd.DataFrame(expected_hparams, index=[0]))
def test_metric_writing_without_validation(self):
    """Verifies the on-disk outputs of a run with no validation dataset.

    Expects a `train` log directory but no `validation` one, a metrics CSV,
    and no hparams CSV (no `hparams_dict` is passed).
    """
    model = compiled_keras_model()
    train_data = create_dataset()
    experiment = 'write_metrics'
    output_root = self.get_temp_dir()

    centralized_training_loop.run(
        keras_model=model,
        train_dataset=train_data,
        experiment_name=experiment,
        root_output_dir=output_root,
        num_epochs=3)

    self.assertTrue(tf.io.gfile.exists(output_root))

    # Log directories: only the training subdirectory should be created.
    log_root = os.path.join(output_root, 'logdir', experiment)
    train_logs = os.path.join(log_root, 'train')
    validation_logs = os.path.join(log_root, 'validation')
    self.assertTrue(tf.io.gfile.exists(log_root))
    self.assertTrue(tf.io.gfile.exists(train_logs))
    self.assertFalse(tf.io.gfile.exists(validation_logs))

    # Results directory: metrics CSV present, hparams CSV absent because
    # hparams_dict was not passed above.
    results_root = os.path.join(output_root, 'results', experiment)
    self.assertTrue(tf.io.gfile.exists(results_root))
    metrics_path = os.path.join(results_root, 'metric_results.csv')
    self.assertTrue(tf.io.gfile.exists(metrics_path))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(results_root, 'hparams.csv')))

    # Expect 3 rows (num_epochs=3) and 3 columns: an unnamed index column
    # plus the two tracked metrics.
    metrics = pd.read_csv(metrics_path)
    self.assertEqual(metrics.shape, (3, 3))
    self.assertCountEqual(metrics.columns,
                          ['Unnamed: 0', 'loss', 'mean_squared_error'])
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    sequence_length: Optional[int] = 80,
                    max_batches: Optional[int] = None):
    """Trains a two-layer RNN on Shakespeare next-character prediction.

    The Shakespeare test split doubles as the validation dataset passed to
    `centralized_training_loop.run`.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      num_epochs: The number of training epochs.
      batch_size: The batch size passed to the dataset loader as
        `train_batch_size`.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      sequence_length: The sequence length used for Shakespeare preprocessing.
        Also used as the model's input sequence length.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.
    """
    train_data, test_data = shakespeare_dataset.get_centralized_datasets(
        train_batch_size=batch_size, sequence_length=sequence_length)

    # Optional cap, typically used to keep smoke tests fast.
    if max_batches is not None and max_batches >= 1:
        train_data = train_data.take(max_batches)
        test_data = test_data.take(max_batches)

    # Only the padding token is needed: it is excluded from the accuracy.
    pad_token, _, _, _ = shakespeare_dataset.get_special_tokens()

    rnn = shakespeare_models.create_recurrent_model(
        vocab_size=VOCAB_SIZE, sequence_length=sequence_length)
    masked_accuracy = keras_metrics.MaskedCategoricalAccuracy(
        masked_tokens=[pad_token])
    rnn.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[masked_accuracy])

    centralized_training_loop.run(
        keras_model=rnn,
        train_dataset=train_data,
        validation_dataset=test_data,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    crop_size: Optional[int] = 24,
                    max_batches: Optional[int] = None,
                    cache_dir: Optional[str] = '~'):
    """Trains a ResNet-18 on CIFAR-100 using a given optimizer.

    The CIFAR-100 test split doubles as the validation dataset passed to
    `centralized_training_loop.run`.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      num_epochs: The number of training epochs.
      batch_size: The batch size passed to the dataset loader as
        `train_batch_size`.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      crop_size: The crop size used for CIFAR-100 preprocessing. Images are
        cropped to `(crop_size, crop_size, NUM_CHANNELS)`, which is also the
        model's input shape.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.
      cache_dir: Passed through unchanged as `cache_dir` to
        `cifar100_dataset.get_centralized_datasets` (presumably the local
        dataset cache location).
    """
    input_shape = (crop_size, crop_size, NUM_CHANNELS)
    train_data, test_data = cifar100_dataset.get_centralized_datasets(
        train_batch_size=batch_size, crop_shape=input_shape,
        cache_dir=cache_dir)

    # Optional cap, typically used to keep smoke tests fast.
    if max_batches is not None and max_batches >= 1:
        train_data, test_data = (train_data.take(max_batches),
                                 test_data.take(max_batches))

    resnet = resnet_models.create_resnet18(input_shape=input_shape,
                                           num_classes=NUM_CLASSES)
    resnet.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=optimizer,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    centralized_training_loop.run(
        keras_model=resnet,
        train_dataset=train_data,
        validation_dataset=test_data,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    max_batches: Optional[int] = None,
                    cache_dir: Optional[str] = None):
    """Trains a bottleneck autoencoder on EMNIST using a given optimizer.

    Loads the full (non-digits-only) EMNIST data configured for the
    `'autoencoder'` task. The eval split serves as the validation dataset.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      num_epochs: The number of training epochs.
      batch_size: The batch size passed to the dataset loader as
        `train_batch_size`.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.
      cache_dir: Passed through unchanged as `cache_dir` to
        `emnist_dataset.get_centralized_datasets` (presumably the local
        dataset cache location).
    """
    train_data, eval_data = emnist_dataset.get_centralized_datasets(
        train_batch_size=batch_size,
        only_digits=False,
        emnist_task='autoencoder',
        cache_dir=cache_dir)

    # Optional cap, typically used to keep smoke tests fast.
    if max_batches is not None and max_batches >= 1:
        train_data, eval_data = (train_data.take(max_batches),
                                 eval_data.take(max_batches))

    autoencoder = emnist_ae_models.create_autoencoder_model()
    # MSE is both the optimized loss and the reported metric.
    autoencoder.compile(
        loss=tf.keras.losses.MeanSquaredError(),
        optimizer=optimizer,
        metrics=[tf.keras.metrics.MeanSquaredError()])

    centralized_training_loop.run(
        keras_model=autoencoder,
        train_dataset=train_data,
        validation_dataset=eval_data,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def test_lr_callback(self):
    """Verifies the learning-rate schedule recorded in the training history.

    With `decay_epochs=8` and `lr_decay=0.5`, the recorded `lr` is expected to
    stay at 0.1 for 8 epochs and then be 0.05 for the remaining 2.
    """
    sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
    model = compiled_keras_model(optimizer=sgd)
    data = create_dataset()

    # The same dataset is reused for validation so that `val_*` keys appear.
    history = centralized_training_loop.run(
        keras_model=model,
        train_dataset=data,
        experiment_name='test_experiment',
        root_output_dir=self.get_temp_dir(),
        num_epochs=10,
        decay_epochs=8,
        lr_decay=0.5,
        validation_dataset=data)

    recorded = history.history
    expected_keys = [
        'loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error', 'lr'
    ]
    self.assertCountEqual(recorded.keys(), expected_keys)
    expected_lrs = [0.1] * 8 + [0.05] * 2
    self.assertAllClose(recorded['lr'], expected_lrs)
def test_training_reduces_loss(self):
    """Verifies that training and validation loss/MSE decrease over 5 epochs."""
    num_epochs = 5
    model = compiled_keras_model()
    data = create_dataset()

    history = centralized_training_loop.run(
        keras_model=model,
        train_dataset=data,
        experiment_name='test_experiment',
        root_output_dir=self.get_temp_dir(),
        num_epochs=num_epochs,
        validation_dataset=data)

    recorded = history.history
    self.assertCountEqual(
        recorded.keys(),
        ['loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error'])
    # Each tracked series should have one entry per epoch and be decreasing.
    for metric_name in ('loss', 'val_loss', 'mean_squared_error',
                        'val_mean_squared_error'):
        self.assertMetricDecreases(recorded[metric_name],
                                   expected_len=num_epochs)
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    vocab_size: int = 10000,
                    num_oov_buckets: int = 1,
                    d_embed: int = 96,
                    d_model: int = 512,
                    d_hidden: int = 2048,
                    num_heads: int = 8,
                    num_layers: int = 1,
                    max_position_encoding: int = 1000,
                    dropout: float = 0.1,
                    num_validation_examples: int = 10000,
                    sequence_length: int = 20,
                    experiment_name: str = 'centralized_stackoverflow',
                    root_output_dir: str = '/tmp/fedopt_guide',
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    max_batches: Optional[int] = None):
    """Trains a Transformer on the Stack Overflow next word prediction task.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      num_epochs: The number of training epochs.
      batch_size: The batch size passed to the dataset loader as
        `train_batch_size`.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      vocab_size: Vocab size for normal tokens.
      num_oov_buckets: Number of out of vocabulary buckets.
      d_embed: Dimension of the token embeddings.
      d_model: Dimension of features of MultiHeadAttention layers.
      d_hidden: Dimension of hidden layers of the FFN.
      num_heads: Number of attention heads.
      num_layers: Number of Transformer blocks.
      max_position_encoding: Maximum number of positions for position
        embeddings.
      dropout: Dropout rate.
      num_validation_examples: The number of test examples to use for
        validation.
      sequence_length: The maximum number of words to take for each sequence.
        Passed to the loader as `max_sequence_length`.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.
    """
    # Keyword arguments make explicit that `sequence_length` feeds the
    # loader's `max_sequence_length`, instead of relying on positional order.
    train_dataset, validation_dataset, test_dataset = (
        stackoverflow_word_prediction.get_centralized_datasets(
            vocab_size=vocab_size,
            max_sequence_length=sequence_length,
            train_batch_size=batch_size,
            num_validation_examples=num_validation_examples,
            num_oov_buckets=num_oov_buckets))

    # Optional cap, typically used to keep smoke tests fast.
    if max_batches and max_batches >= 1:
        train_dataset = train_dataset.take(max_batches)
        validation_dataset = validation_dataset.take(max_batches)
        test_dataset = test_dataset.take(max_batches)

    model = transformer_models.create_transformer_lm(
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        d_embed=d_embed,
        d_model=d_model,
        d_hidden=d_hidden,
        num_heads=num_heads,
        num_layers=num_layers,
        max_position_encoding=max_position_encoding,
        dropout=dropout,
        name='stackoverflow-transformer')

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        vocab_size=vocab_size, num_oov_buckets=num_oov_buckets)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    # Three accuracy variants that progressively mask more special tokens:
    # padding only; padding + OOV; padding + EOS + OOV.
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=[
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_with_oov', masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
        ])

    centralized_training_loop.run(
        keras_model=model,
        train_dataset=train_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    vocab_size: Optional[int] = 10000,
                    num_oov_buckets: Optional[int] = 1,
                    sequence_length: Optional[int] = 20,
                    num_validation_examples: Optional[int] = 10000,
                    embedding_size: Optional[int] = 96,
                    latent_size: Optional[int] = 670,
                    num_layers: Optional[int] = 1,
                    shared_embedding: Optional[bool] = False,
                    max_batches: Optional[int] = None,
                    cache_dir: Optional[str] = None):
    """Trains an RNN on the Stack Overflow next word prediction task.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      num_epochs: The number of training epochs.
      batch_size: The batch size passed to the dataset loader as
        `train_batch_size`.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      vocab_size: Integer dictating the number of most frequent words to use in
        the vocabulary.
      num_oov_buckets: The number of out-of-vocabulary buckets to use.
      sequence_length: The maximum number of words to take for each sequence.
      num_validation_examples: The number of test examples to use for
        validation.
      embedding_size: The dimension of the word embedding layer.
      latent_size: The dimension of the latent units in the recurrent layers.
      num_layers: The number of stacked recurrent layers to use.
      shared_embedding: Boolean indicating whether to tie input and output
        embeddings.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.
      cache_dir: Passed through unchanged as `cache_dir` to
        `stackoverflow_word_prediction.get_centralized_datasets` (presumably
        the local dataset cache location).
    """
    splits = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=vocab_size,
        max_sequence_length=sequence_length,
        train_batch_size=batch_size,
        num_validation_examples=num_validation_examples,
        num_oov_buckets=num_oov_buckets,
        cache_dir=cache_dir)
    train_data, val_data, test_data = splits

    # Optional cap, typically used to keep smoke tests fast.
    if max_batches is not None and max_batches >= 1:
        train_data, val_data, test_data = (train_data.take(max_batches),
                                           val_data.take(max_batches),
                                           test_data.take(max_batches))

    lstm = stackoverflow_models.create_recurrent_model(
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        name='stackoverflow-lstm',
        embedding_size=embedding_size,
        latent_size=latent_size,
        num_layers=num_layers,
        shared_embedding=shared_embedding)

    tokens = stackoverflow_word_prediction.get_special_tokens(
        vocab_size=vocab_size, num_oov_buckets=num_oov_buckets)
    pad, oov, eos = tokens.pad, tokens.oov, tokens.eos
    # (metric name, tokens excluded from that accuracy), from least to most
    # masking.
    accuracy_specs = [
        ('accuracy_with_oov', [pad]),
        ('accuracy_no_oov', [pad] + oov),
        ('accuracy_no_oov_or_eos', [pad, eos] + oov),
    ]
    lstm.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=[
            keras_metrics.MaskedCategoricalAccuracy(name=metric_name,
                                                    masked_tokens=masked)
            for metric_name, masked in accuracy_specs
        ])

    centralized_training_loop.run(
        keras_model=lstm,
        train_dataset=train_data,
        validation_dataset=val_data,
        test_dataset=test_data,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    vocab_tokens_size: Optional[int] = 10000,
                    vocab_tags_size: Optional[int] = 500,
                    num_validation_examples: Optional[int] = 10000,
                    max_batches: Optional[int] = None):
    """Trains a logistic model on the Stack Overflow tag prediction task.

    The model is built by `stackoverflow_lr_models.create_logistic_model` and
    trained on data from `stackoverflow_tag_prediction`.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      num_epochs: The number of training epochs.
      batch_size: The batch size passed to the dataset loader as
        `train_batch_size`.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      vocab_tokens_size: Integer dictating the number of most frequent words to
        use in the vocabulary.
      vocab_tags_size: Integer dictating the number of most frequent tags to
        use in the label creation.
      num_validation_examples: The number of test examples to use for
        validation.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.
    """
    train_dataset, validation_dataset, test_dataset = (
        stackoverflow_tag_prediction.get_centralized_datasets(
            train_batch_size=batch_size,
            word_vocab_size=vocab_tokens_size,
            tag_vocab_size=vocab_tags_size,
            num_validation_examples=num_validation_examples))

    # Optional cap, typically used to keep smoke tests fast.
    if max_batches and max_batches >= 1:
        train_dataset = train_dataset.take(max_batches)
        validation_dataset = validation_dataset.take(max_batches)
        test_dataset = test_dataset.take(max_batches)

    model = stackoverflow_lr_models.create_logistic_model(
        vocab_tokens_size=vocab_tokens_size, vocab_tags_size=vocab_tags_size)

    # from_logits=False: the model's outputs are treated as probabilities.
    # Loss uses SUM reduction, so its scale grows with the batch size.
    # Precision is computed without top_k; recall uses top_k=5.
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
        optimizer=optimizer,
        metrics=[tf.keras.metrics.Precision(),
                 tf.keras.metrics.Recall(top_k=5)])

    centralized_training_loop.run(
        keras_model=model,
        train_dataset=train_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    emnist_model: Optional[str] = 'cnn',
                    max_batches: Optional[int] = None):
    """Trains a model on EMNIST character recognition using a given optimizer.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      num_epochs: The number of training epochs.
      batch_size: The batch size passed to the dataset loader as
        `train_batch_size`.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      emnist_model: A string specifying the model used for character
        recognition. Can be one of `cnn` and `2nn`, corresponding to a CNN
        model and a densely connected 2-layer model (respectively).
      max_batches: Forwarded to the dataset loader as both
        `max_train_batches` and `max_test_batches`. Intended semantics: if set
        to a positive integer, datasets are capped to at most that many
        batches; if None or nonpositive, the full datasets are used (the
        capping itself is performed by the loader).

    Raises:
      ValueError: If `emnist_model` is not one of `cnn` or `2nn`. This is
        checked before any data is loaded.
    """
    # Resolve the model builder up front so a bad flag fails fast, before
    # the (potentially expensive) dataset loading below.
    model_builders = {
        'cnn': emnist_models.create_conv_dropout_model,
        '2nn': emnist_models.create_two_hidden_layer_model,
    }
    if emnist_model not in model_builders:
        raise ValueError(
            'Cannot handle model flag [{!s}].'.format(emnist_model))

    train_dataset, eval_dataset = emnist_dataset.get_centralized_datasets(
        train_batch_size=batch_size,
        max_train_batches=max_batches,
        max_test_batches=max_batches,
        only_digits=False)

    model = model_builders[emnist_model](only_digits=False)
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=optimizer,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    centralized_training_loop.run(
        keras_model=model,
        train_dataset=train_dataset,
        validation_dataset=eval_dataset,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def run_centralized(
        optimizer: tf.keras.optimizers.Optimizer,
        image_size: int,
        num_epochs: int,
        batch_size: int,
        num_groups: int = 8,
        dataset_type: dataset.DatasetType = dataset.DatasetType.GLD23K,
        experiment_name: str = 'centralized_gld23k',
        root_output_dir: str = '/tmp/fedopt_guide',
        dropout_prob: Optional[float] = None,
        hparams_dict: Optional[Mapping[str, Any]] = None,
        max_batches: Optional[int] = None):
    """Trains a MobileNetV2 on the Google Landmark datasets.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      image_size: The height and width of images after preprocessing.
      num_epochs: The number of training epochs.
      batch_size: The batch size, used for train and test.
      num_groups: The number of groups in the GroupNorm layers of MobilenetV2.
      dataset_type: A `dataset.DatasetType` specifying which dataset is used for
        experiments.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      dropout_prob: Probability of setting a weight to zero in the dropout
        layer of MobilenetV2. Must be in the range [0, 1). Setting it to None
        (default) or zero means no dropout.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.

    Raises:
      ValueError: If `dropout_prob` is not None and not in [0, 1) (this
        includes NaN). Checked before any data is loaded.
    """
    # Validate cheap arguments before the expensive dataset loading. The
    # chained comparison also rejects NaN, which `x < 0 or x >= 1` let through.
    if dropout_prob is not None and not 0 <= dropout_prob < 1:
        raise ValueError(
            f'Expected a value in [0, 1) for `dropout_prob`, found {dropout_prob}.'
        )

    train_data, test_data = dataset.get_centralized_datasets(
        image_size=image_size, batch_size=batch_size, dataset_type=dataset_type)
    num_classes, _ = dataset.get_dataset_stats(dataset_type)

    # Optional cap, typically used to keep smoke tests fast.
    if max_batches and max_batches >= 1:
        train_data = train_data.take(max_batches)
        test_data = test_data.take(max_batches)

    # Inputs are treated as 3-channel images of size image_size x image_size.
    model = mobilenet_v2.create_mobilenet_v2(
        input_shape=(image_size, image_size, 3),
        num_groups=num_groups,
        num_classes=num_classes,
        dropout_prob=dropout_prob)
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=optimizer,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    # The test split doubles as the validation dataset; no LR decay here.
    centralized_training_loop.run(
        keras_model=model,
        train_dataset=train_data,
        validation_dataset=test_data,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict)