class DistributionStrategyStatefulLstmModelCorrectnessTest(
    keras_correctness_test_base.
    TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a stateful-LSTM classifier under tf.distribute."""

  def get_model(self, max_words=10, initial_weights=None, distribution=None):
    """Builds and compiles the stateful-LSTM word classifier.

    Args:
      max_words: Sequence length of the integer word-id input.
      initial_weights: Optional weights to load into the freshly built model.
      distribution: Optional tf.distribute strategy; model construction and
        compilation happen inside its scope when provided.

    Returns:
      A compiled functional `keras.Model`.
    """
    # Stateful RNNs require a fixed batch size on the Input layer.
    batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      word_ids = keras.layers.Input(
          shape=(max_words,),
          batch_size=batch_size,
          dtype=np.int32,
          name='words')
      word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids)
      lstm_embed = keras.layers.LSTM(
          units=4, return_sequences=False, stateful=True)(word_embed)
      preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
      model = keras.Model(inputs=[word_ids], outputs=[preds])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent.GradientDescentOptimizer(
              learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
      return model

  @combinations.generate(test_combinations_for_stateful_embedding_model())
  def test_stateful_lstm_model_correctness(self, distribution, use_numpy,
                                           use_validation_data):
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data, is_stateful_model=True)

  @combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies())
  def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
      self, distribution, use_numpy, use_validation_data):
    # Stateful models must run on a single core; multi-core TPU strategies are
    # expected to fail fast with this error.
    # Fixed: assertRaisesRegex replaces the deprecated assertRaisesRegexp,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'Single core must be used for computation '
        'on stateful models. Consider adding '
        '`device_assignment` parameter to '
        'TPUStrategy'):
      self.run_correctness_test(
          distribution, use_numpy, use_validation_data, is_stateful_model=True)
class TestDistributionStrategyDnnCorrectness(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
  """Correctness tests for a small DNN regression model under tf.distribute.

  The task is the linear map y = 3 * x on random scalar inputs; the data
  helpers only vary set sizes to exercise even vs. partial final batches.
  """

  def get_model(self, initial_weights=None, distribution=None,
                input_shapes=None):
    """Builds and compiles a 4-layer Sequential regressor.

    Args:
      initial_weights: Optional weights to load into the freshly built model.
      distribution: Optional tf.distribute strategy; model construction and
        compilation happen inside its scope when provided.
      input_shapes: Unused here; the input shape is fixed to (1,).

    Returns:
      A compiled `keras.Sequential` model.
    """
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # We add few non-linear layers to make it non-trivial.
      model = keras.Sequential()
      model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,)))
      model.add(
          keras.layers.Dense(
              10,
              activation='relu',
              kernel_regularizer=keras.regularizers.l2(1e-4)))
      model.add(keras.layers.Dense(10, activation='relu'))
      model.add(keras.layers.Dense(1))

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          loss=keras.losses.mean_squared_error,
          optimizer=gradient_descent_keras.SGD(0.05),
          metrics=['mse'])
      return model

  def _linear_data(self, count):
    """Returns (x, 3*x) for `count` random float32 scalars in [0, 1)."""
    # Factored out of the three get_data* methods below, which previously
    # duplicated this generation inline; the np.random call order is unchanged.
    x = np.random.rand(count, 1).astype('float32')
    return x, 3 * x

  def get_data(self):
    """Training data sized to divide the global batch size evenly."""
    x_train, y_train = self._linear_data(9984)
    x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
    return x_train, y_train, x_predict

  def get_data_with_partial_last_batch(self):
    """Train and eval sets both sized to end in a partial final batch."""
    x_train, y_train = self._linear_data(10000)
    x_eval, y_eval = self._linear_data(10000)
    x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
    return x_train, y_train, x_eval, y_eval, x_predict

  def get_data_with_partial_last_batch_eval(self):
    """Evenly-batched train set; eval set ends in a partial final batch."""
    x_train, y_train = self._linear_data(9984)
    x_eval, y_eval = self._linear_data(10000)
    x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
    return x_train, y_train, x_eval, y_eval, x_predict

  @combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
    self.run_correctness_test(distribution, use_numpy, use_validation_data)

  @combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies())
  def test_dnn_correctness_with_partial_last_batch_eval(
      self, distribution, use_numpy, use_validation_data):
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data,
        partial_last_batch='eval')

  @combinations.generate(
      keras_correctness_test_base
      .strategy_minus_tpu_and_input_config_combinations_eager())
  def test_dnn_correctness_with_partial_last_batch(self, distribution,
                                                   use_numpy,
                                                   use_validation_data):
    # Partial final batches require get_next_as_optional support from the
    # strategy's input iterators.
    distribution.extended.experimental_enable_get_next_as_optional = True
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        partial_last_batch='train_and_eval',
        training_epochs=1)

  @combinations.generate(all_strategy_combinations_with_graph_mode())
  def test_dnn_with_dynamic_learning_rate(self, distribution):
    self.run_dynamic_lr_test(distribution)
class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
    TestDistributionStrategyDnnCorrectness):
  """Repeats the DNN correctness tests with a subclassed (non-functional) model.

  Subclassed models are unsupported by several strategy/execution-mode
  combinations, so most tests here assert the expected ValueError instead of
  running the correctness check.
  """

  def get_model(self, initial_weights=None, distribution=None,
                input_shapes=None):
    """Builds and compiles a `SubclassedModel` under `distribution`'s scope."""
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      model = SubclassedModel(initial_weights, input_shapes)

      model.compile(
          loss=keras.losses.mean_squared_error,
          optimizer=gradient_descent_keras.SGD(0.05),
          metrics=['mse'])
      return model

  @combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
    # Fixed: assertRaisesRegex replaces the deprecated assertRaisesRegexp,
    # which was removed in Python 3.12.
    if (context.executing_eagerly()) or is_default_strategy(distribution):
      self.run_correctness_test(distribution, use_numpy, use_validation_data)
    elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
      # Graph-mode TPU strategies require a functional model.
      with self.assertRaisesRegex(
          ValueError,
          'Expected `model` argument to be a functional `Model` instance, '
          'but got a subclass model instead.'):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)
    else:
      with self.assertRaisesRegex(
          ValueError,
          'We currently do not support distribution strategy with a '
          '`Sequential` model that is created without `input_shape`/'
          '`input_dim` set in its first layer or a subclassed model.'):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

  @combinations.generate(all_strategy_combinations_with_graph_mode())
  def test_dnn_with_dynamic_learning_rate(self, distribution):
    if ((context.executing_eagerly() and not K.is_tpu_strategy(distribution))
        or is_default_strategy(distribution)):
      self.run_dynamic_lr_test(distribution)
    elif K.is_tpu_strategy(distribution):
      with self.assertRaisesRegex(
          ValueError,
          'Expected `model` argument to be a functional `Model` instance, '
          'but got a subclass model instead.'):
        self.run_dynamic_lr_test(distribution)
    else:
      with self.assertRaisesRegex(
          ValueError,
          'We currently do not support distribution strategy with a '
          '`Sequential` model that is created without `input_shape`/'
          '`input_dim` set in its first layer or a subclassed model.'):
        self.run_dynamic_lr_test(distribution)

  @combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies())
  def test_dnn_correctness_with_partial_last_batch_eval(
      self, distribution, use_numpy, use_validation_data):
    with self.assertRaisesRegex(
        ValueError,
        'Expected `model` argument to be a functional `Model` instance, '
        'but got a subclass model instead.'):
      self.run_correctness_test(
          distribution, use_numpy, use_validation_data,
          partial_last_batch='eval')
class DistributionStrategyStatefulLstmModelCorrectnessTest(
    keras_correctness_test_base.
    TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a stateful-LSTM classifier under tf.distribute."""

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                experimental_run_tf_function=None,
                input_shapes=None):
    """Builds and compiles the stateful-LSTM word classifier.

    Args:
      max_words: Sequence length of the integer word-id input.
      initial_weights: Optional weights to load into the freshly built model.
      distribution: Optional tf.distribute strategy; model construction and
        compilation happen inside its scope when provided.
      experimental_run_tf_function: Accepted for signature compatibility with
        the test driver; not consumed here.
      input_shapes: Unused; the stateful input shape is fixed below.

    Returns:
      A compiled functional `keras.Model`.
    """
    del input_shapes
    # Stateful RNNs require a fixed batch size on the Input layer.
    batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      word_ids = keras.layers.Input(
          shape=(max_words,),
          batch_size=batch_size,
          dtype=np.int32,
          name='words')
      word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids)
      lstm_embed = keras.layers.LSTM(
          units=4, return_sequences=False, stateful=True)(word_embed)
      preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
      model = keras.Model(inputs=[word_ids], outputs=[preds])

      if initial_weights:
        model.set_weights(initial_weights)

      optimizer_fn = gradient_descent_keras.SGD
      model.compile(
          optimizer=optimizer_fn(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
      return model

  # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
  # doesn't work and enable for DistributionStrategy more generally.
  @combinations.generate(test_combinations_for_stateful_embedding_model())
  def disabled_test_stateful_lstm_model_correctness(
      self, distribution, use_numpy, use_validation_data,
      experimental_run_tf_function):
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        is_stateful_model=True,
        experimental_run_tf_function=experimental_run_tf_function)

  @combinations.generate(
      combinations.times(
          keras_correctness_test_base.test_combinations_with_tpu_strategies(),
          combinations.combine(experimental_run_tf_function=[True, False])))
  def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
      self, distribution, use_numpy, use_validation_data,
      experimental_run_tf_function):
    # Stateful models must run on a single core; multi-core TPU strategies are
    # expected to fail fast with this error.
    # Fixed: assertRaisesRegex replaces the deprecated assertRaisesRegexp,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(
        ValueError,
        'Single core must be used for computation on stateful models. '
        'Consider adding `device_assignment` parameter to TPUStrategy'):
      self.run_correctness_test(
          distribution,
          use_numpy,
          use_validation_data,
          is_stateful_model=True,
          experimental_run_tf_function=experimental_run_tf_function)
class DistributionStrategyCnnCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
  """Correctness tests for a small CNN image classifier under tf.distribute."""

  def get_model(self, initial_weights=None, distribution=None,
                input_shapes=None):
    """Builds and compiles the conv/pool/dense classifier.

    The batch-norm variant is selected by `self.with_batch_norm`
    ('regular' or 'sync'); anything else builds no batch-norm layer.

    Args:
      initial_weights: Optional weights to load into the freshly built model.
      distribution: Optional tf.distribute strategy; model construction and
        compilation happen inside its scope when provided.
      input_shapes: Unused; the image shape is fixed to (28, 28, 3).

    Returns:
      A compiled functional `keras.Model`.
    """
    del input_shapes
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      image = keras.layers.Input(shape=(28, 28, 3), name='image')
      c1 = keras.layers.Conv2D(
          name='conv1',
          filters=16,
          kernel_size=(3, 3),
          strides=(4, 4),
          kernel_regularizer=keras.regularizers.l2(1e-4))(image)
      if self.with_batch_norm == 'regular':
        c1 = keras.layers.BatchNormalization(name='bn1')(c1)
      elif self.with_batch_norm == 'sync':
        c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
      c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
      logits = keras.layers.Dense(
          10, activation='softmax', name='pred')(keras.layers.Flatten()(c1))
      model = keras.Model(inputs=[image], outputs=[logits])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
      return model

  def _get_data(self, count, shape=(28, 28, 3), num_classes=10):
    """Samples `count` noisy examples clustered around per-class centers.

    Each example is a randomly chosen class center plus N(0, 0.1) noise.

    Returns:
      Tuple (x, y): x is float32 of shape (count, *shape); y is float32
      labels of shape (count, 1).
    """
    centers = np.random.randn(num_classes, *shape)

    features = []
    labels = []
    for _ in range(count):
      # Idiom fix: draw a scalar label directly instead of indexing into a
      # size-1 array (previously `randint(0, num_classes, size=1)[0]`).
      label = np.random.randint(0, num_classes)
      offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape))
      offset = offset.reshape(shape)
      labels.append(label)
      features.append(centers[label] + offset)

    x = np.asarray(features, dtype=np.float32)
    y = np.asarray(labels, dtype=np.float32).reshape((count, 1))
    return x, y

  def get_data(self):
    """Training data: one global batch per eval step; predict reuses x_train."""
    x_train, y_train = self._get_data(
        count=keras_correctness_test_base._GLOBAL_BATCH_SIZE *
        keras_correctness_test_base._EVAL_STEPS)
    x_predict = x_train
    return x_train, y_train, x_predict

  def get_data_with_partial_last_batch_eval(self):
    """Evenly-batched train set; 1000-example eval set ends in a partial batch.

    Prediction deliberately reuses the eval inputs.
    """
    x_train, y_train = self._get_data(count=1280)
    x_eval, y_eval = self._get_data(count=1000)
    return x_train, y_train, x_eval, y_eval, x_eval

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_correctness(self, distribution, use_numpy, use_validation_data):
    self.run_correctness_test(distribution, use_numpy, use_validation_data)

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
                                           use_validation_data):
    self.skipTest('Flakily times out, b/134670856')
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data,
        with_batch_norm='regular')

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
                                                use_validation_data):
    if not context.executing_eagerly():
      self.skipTest('SyncBatchNorm is not enabled in graph mode.')

    self.run_correctness_test(
        distribution, use_numpy, use_validation_data, with_batch_norm='sync')

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies() +
      keras_correctness_test_base
      .strategy_minus_tpu_and_input_config_combinations_eager())
  def test_cnn_correctness_with_partial_last_batch_eval(
      self, distribution, use_numpy, use_validation_data):
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        partial_last_batch=True,
        training_epochs=1)

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies() +
      keras_correctness_test_base
      .strategy_minus_tpu_and_input_config_combinations_eager())
  def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
      self, distribution, use_numpy, use_validation_data):
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        with_batch_norm='regular',
        partial_last_batch=True)