class DistributionStrategyLstmModelCorrectnessTest(
    _DistributionStrategyRnnModelCorrectnessTest
):
    """Correctness tests for LSTM models under distribution strategies."""

    def _get_layer_class(self):
        # Legacy (v1) LSTM unless TF2 behavior is enabled; the v2 layer only
        # works eagerly.
        if not tf.__internal__.tf2.enabled():
            return lstm_v1.LSTM
        if not tf.executing_eagerly():
            self.skipTest("LSTM v2 and legacy graph mode don't work together.")
        return lstm.LSTM

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_lstm_model_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    @test_utils.enable_v2_dtype_behavior
    def test_lstm_model_correctness_mixed_precision(
        self, distribution, use_numpy, use_validation_data
    ):
        central_storage_strategies = (
            tf.distribute.experimental.CentralStorageStrategy,
            tf.compat.v1.distribute.experimental.CentralStorageStrategy,
        )
        if isinstance(distribution, central_storage_strategies):
            self.skipTest(
                "CentralStorageStrategy is not supported by "
                "mixed precision."
            )

        tpu_strategies = (
            tf.distribute.experimental.TPUStrategy,
            tf.compat.v1.distribute.experimental.TPUStrategy,
        )
        # TPUs use bfloat16; all other strategies use float16.
        policy_name = (
            "mixed_bfloat16"
            if isinstance(distribution, tpu_strategies)
            else "mixed_float16"
        )
        with policy.policy_scope(policy_name):
            self.run_correctness_test(
                distribution, use_numpy, use_validation_data
            )
class DistributionStrategyEmbeddingModelCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyEmbeddingModelCorrectnessBase  # noqa: E501
):
    """Correctness tests for a distributed embedding classifier."""

    def get_model(
        self,
        max_words=10,
        initial_weights=None,
        distribution=None,
        input_shapes=None,
    ):
        del input_shapes
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            tokens = keras.layers.Input(
                shape=(max_words,), dtype=np.int32, name="words"
            )
            embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(
                tokens
            )
            if self.use_distributed_dense:
                embedded = keras.layers.TimeDistributed(
                    keras.layers.Dense(4)
                )(embedded)
            pooled = keras.layers.GlobalAveragePooling1D()(embedded)
            predictions = keras.layers.Dense(2, activation="softmax")(pooled)

            model = keras.Model(inputs=[tokens], outputs=[predictions])
            if initial_weights:
                model.set_weights(initial_weights)
            model.compile(
                optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
                loss="sparse_categorical_crossentropy",
                metrics=["sparse_categorical_accuracy"],
            )
            return model

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_embedding_model_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        self.use_distributed_dense = False
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_embedding_time_distributed_model_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        self.use_distributed_dense = True
        self.run_correctness_test(distribution, use_numpy, use_validation_data)
class DistributionStrategyGruModelCorrectnessTest(
    _DistributionStrategyRnnModelCorrectnessTest
):
    """Correctness tests for GRU models under distribution strategies."""

    def _get_layer_class(self):
        # Legacy (v1) GRU unless TF2 behavior is enabled; the v2 layer only
        # works eagerly.
        if not tf.__internal__.tf2.enabled():
            return rnn_v1.GRU
        if not tf.executing_eagerly():
            self.skipTest("GRU v2 and legacy graph mode don't work together.")
        return rnn_v2.GRU

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_gru_model_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)
class DistributionStrategyCnnCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase
):
    """Correctness tests for a small distributed CNN classifier."""

    def get_model(
        self, initial_weights=None, distribution=None, input_shapes=None
    ):
        del input_shapes
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            image = keras.layers.Input(shape=(28, 28, 3), name="image")
            features = keras.layers.Conv2D(
                name="conv1",
                filters=16,
                kernel_size=(3, 3),
                strides=(4, 4),
                kernel_regularizer=keras.regularizers.l2(1e-4),
            )(image)
            if self.with_batch_norm == "regular":
                features = keras.layers.BatchNormalization(name="bn1")(
                    features
                )
            elif self.with_batch_norm == "sync":
                # Two parallel sync batch norms verify that all-reduce works.
                branch_a = keras.layers.SyncBatchNormalization(name="bn1")(
                    features
                )
                branch_b = keras.layers.SyncBatchNormalization(name="bn2")(
                    features
                )
                features = keras.layers.Add()([branch_a, branch_b])
            features = keras.layers.MaxPooling2D(pool_size=(2, 2))(features)
            flattened = keras.layers.Flatten()(features)
            logits = keras.layers.Dense(10, activation="softmax", name="pred")(
                flattened
            )

            model = keras.Model(inputs=[image], outputs=[logits])
            if initial_weights:
                model.set_weights(initial_weights)
            model.compile(
                optimizer=gradient_descent.SGD(learning_rate=0.1),
                loss="sparse_categorical_crossentropy",
                metrics=["sparse_categorical_accuracy"],
            )
            return model

    def _get_data(self, count, shape=(28, 28, 3), num_classes=10):
        """Return `count` noisy samples clustered around per-class centers."""
        centers = np.random.randn(num_classes, *shape)
        features = []
        labels = []
        for _ in range(count):
            # NOTE: the randint/normal call order must stay fixed so results
            # are reproducible under a seeded global RNG.
            label = np.random.randint(0, num_classes, size=1)[0]
            noise = np.random.normal(loc=0, scale=0.1, size=np.prod(shape))
            labels.append(label)
            features.append(centers[label] + noise.reshape(shape))

        x = np.asarray(features, dtype=np.float32)
        y = np.asarray(labels, dtype=np.float32).reshape((count, 1))
        return x, y

    def get_data(self):
        sample_count = (
            keras_correctness_test_base._GLOBAL_BATCH_SIZE
            * keras_correctness_test_base._EVAL_STEPS
        )
        x_train, y_train = self._get_data(count=sample_count)
        return x_train, y_train, x_train

    def get_data_with_partial_last_batch_eval(self):
        x_train, y_train = self._get_data(count=1280)
        x_eval, y_eval = self._get_data(count=1000)
        return x_train, y_train, x_eval, y_eval, x_eval

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        central_storage = (
            tf.__internal__.distribute.combinations.central_storage_strategy_with_gpu_and_cpu  # noqa: E501
        )
        if distribution == central_storage:
            self.skipTest("b/183958183")
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_sync_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        if not tf.executing_eagerly():
            self.skipTest("SyncBatchNorm is not enabled in graph mode.")
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="sync",
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_correctness_with_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            partial_last_batch=True,
            training_epochs=1,
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
            partial_last_batch=True,
        )
class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyEmbeddingModelCorrectnessBase  # noqa: E501
):
    """Correctness tests for a distributed siamese embedding model."""

    def get_model(
        self,
        max_words=10,
        initial_weights=None,
        distribution=None,
        input_shapes=None,
    ):
        del input_shapes
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            words_a = keras.layers.Input(
                shape=(max_words,), dtype=np.int32, name="words_a"
            )
            words_b = keras.layers.Input(
                shape=(max_words,), dtype=np.int32, name="words_b"
            )

            def tower(embedding, word_ids):
                # Shared-embedding branch: embed, then average over time.
                embedded = embedding(word_ids)
                pooled = keras.layers.GlobalAveragePooling1D()(embedded)
                return keras.Model(inputs=[word_ids], outputs=[pooled])

            shared_embedding = keras.layers.Embedding(
                input_dim=20,
                output_dim=10,
                input_length=max_words,
                embeddings_initializer=keras.initializers.RandomUniform(0, 1),
            )
            rep_a = tower(shared_embedding, words_a).outputs[0]
            rep_b = tower(shared_embedding, words_b).outputs[0]
            similarity = keras.layers.Dot(axes=1, normalize=True)(
                [rep_a, rep_b]
            )

            model = keras.Model(
                inputs=[words_a, words_b], outputs=[similarity]
            )
            if initial_weights:
                model.set_weights(initial_weights)
            # TODO(b/130808953): Switch back to the V1 optimizer after
            # global_step is made mirrored.
            model.compile(
                optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
                loss="mse",
                metrics=["mse"],
            )
            return model

    def get_data(
        self,
        count=(
            keras_correctness_test_base._GLOBAL_BATCH_SIZE
            * keras_correctness_test_base._EVAL_STEPS
        ),
        min_words=5,
        max_words=10,
        max_word_id=19,
        num_classes=2,
    ):
        features_a, labels_a, _ = super().get_data(
            count, min_words, max_words, max_word_id, num_classes
        )
        features_b, labels_b, _ = super().get_data(
            count, min_words, max_words, max_word_id, num_classes
        )

        # Target is +1 for same-class pairs and -1 otherwise.
        y_train = np.zeros((count, 1), dtype=np.float32)
        y_train[labels_a == labels_b] = 1.0
        y_train[labels_a != labels_b] = -1.0

        # TODO(b/123360757): Add tests for using list as inputs for
        # multi-input models.
        x_train = {
            "words_a": features_a,
            "words_b": features_b,
        }
        return x_train, y_train, x_train

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_siamese_embedding_model_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)
class TestDistributionStrategyDnnCorrectness(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase
):
    """Correctness tests for a small distributed DNN regressor."""

    def get_model(
        self, initial_weights=None, distribution=None, input_shapes=None
    ):
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            # A few non-linear layers keep the problem non-trivial.
            model = keras.Sequential(
                [
                    keras.layers.Dense(
                        10, activation="relu", input_shape=(1,)
                    ),
                    keras.layers.Dense(
                        10,
                        activation="relu",
                        kernel_regularizer=keras.regularizers.l2(1e-4),
                    ),
                    keras.layers.Dense(10, activation="relu"),
                    keras.layers.Dense(1),
                ]
            )

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(
                loss=keras.losses.mean_squared_error,
                optimizer=gradient_descent_keras.SGD(0.05),
                metrics=["mse"],
            )
            return model

    def get_data(self):
        # Target is y = 3x; prediction inputs are a fixed small probe set.
        x_train = np.random.rand(9984, 1).astype("float32")
        y_train = 3 * x_train
        x_predict = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)
        return x_train, y_train, x_predict

    def get_data_with_partial_last_batch(self):
        x_train = np.random.rand(10000, 1).astype("float32")
        y_train = 3 * x_train
        x_eval = np.random.rand(10000, 1).astype("float32")
        y_eval = 3 * x_eval
        x_predict = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)
        return x_train, y_train, x_eval, y_eval, x_predict

    def get_data_with_partial_last_batch_eval(self):
        x_train = np.random.rand(9984, 1).astype("float32")
        y_train = 3 * x_train
        x_eval = np.random.rand(10000, 1).astype("float32")
        y_eval = 3 * x_eval
        x_predict = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)
        return x_train, y_train, x_eval, y_eval, x_predict

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_dnn_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_dnn_correctness_with_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            partial_last_batch="eval",
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.strategy_minus_tpu_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_dnn_correctness_with_partial_last_batch(
        self, distribution, use_numpy, use_validation_data
    ):
        distribution.extended.experimental_enable_get_next_as_optional = True
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            partial_last_batch="train_and_eval",
            training_epochs=1,
        )

    @tf.__internal__.distribute.combinations.generate(
        all_strategy_combinations_with_graph_mode()
    )
    def test_dnn_with_dynamic_learning_rate(self, distribution):
        self.run_dynamic_lr_test(distribution)
class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
    TestDistributionStrategyDnnCorrectness
):
    """DNN correctness tests using a subclassed (non-functional) model.

    Subclassed models are only supported in some execution modes; the other
    modes are expected to raise a descriptive ValueError.
    """

    def get_model(
        self, initial_weights=None, distribution=None, input_shapes=None
    ):
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            model = SubclassedModel(initial_weights, input_shapes)

            model.compile(
                loss=keras.losses.mean_squared_error,
                optimizer=gradient_descent_keras.SGD(0.05),
                metrics=["mse"],
            )
            return model

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_dnn_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        functional_error = (
            "Expected `model` argument to be a functional `Model` instance, "
            "but got a subclass model instead."
        )
        sequential_error = (
            "We currently do not support distribution strategy with a "
            "`Sequential` model that is created without `input_shape`/"
            "`input_dim` set in its first layer or a subclassed model."
        )
        if tf.executing_eagerly() or is_default_strategy(distribution):
            self.run_correctness_test(
                distribution, use_numpy, use_validation_data
            )
        elif backend.is_tpu_strategy(distribution) and not tf.executing_eagerly():
            with self.assertRaisesRegex(ValueError, functional_error):
                self.run_correctness_test(
                    distribution, use_numpy, use_validation_data
                )
        else:
            with self.assertRaisesRegex(ValueError, sequential_error):
                self.run_correctness_test(
                    distribution, use_numpy, use_validation_data
                )

    @tf.__internal__.distribute.combinations.generate(
        all_strategy_combinations_with_graph_mode()
    )
    def test_dnn_with_dynamic_learning_rate(self, distribution):
        functional_error = (
            "Expected `model` argument to be a functional `Model` instance, "
            "but got a subclass model instead."
        )
        sequential_error = (
            "We currently do not support distribution strategy with a "
            "`Sequential` model that is created without `input_shape`/"
            "`input_dim` set in its first layer or a subclassed model."
        )
        eager_non_tpu = tf.executing_eagerly() and not backend.is_tpu_strategy(
            distribution
        )
        if eager_non_tpu or is_default_strategy(distribution):
            self.run_dynamic_lr_test(distribution)
        elif backend.is_tpu_strategy(distribution):
            with self.assertRaisesRegex(ValueError, functional_error):
                self.run_dynamic_lr_test(distribution)
        else:
            with self.assertRaisesRegex(ValueError, sequential_error):
                self.run_dynamic_lr_test(distribution)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_dnn_correctness_with_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        functional_error = (
            "Expected `model` argument to be a functional `Model` instance, "
            "but got a subclass model instead."
        )
        with self.assertRaisesRegex(ValueError, functional_error):
            self.run_correctness_test(
                distribution,
                use_numpy,
                use_validation_data,
                partial_last_batch="eval",
            )