class DistributionStrategyStatefulLstmModelCorrectnessTest(
        keras_correctness_test_base.
        TestDistributionStrategyEmbeddingModelCorrectnessBase):
    """Correctness tests for a stateful LSTM model under distribution strategies."""

    def get_model(self, max_words=10, initial_weights=None, distribution=None):
        """Builds and compiles a stateful-LSTM sequence classifier.

        Args:
          max_words: Length of the integer word-id input sequence.
          initial_weights: Optional weights to load so that runs under
            different strategies start from identical parameters.
          distribution: Optional distribution strategy; the model is built
            inside its scope when provided.

        Returns:
          A compiled `keras.Model`.
        """
        # Stateful RNNs require a fixed batch size on the Input layer.
        batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE

        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            word_ids = keras.layers.Input(shape=(max_words, ),
                                          batch_size=batch_size,
                                          dtype=np.int32,
                                          name='words')
            word_embed = keras.layers.Embedding(input_dim=20,
                                                output_dim=10)(word_ids)
            lstm_embed = keras.layers.LSTM(units=4,
                                           return_sequences=False,
                                           stateful=True)(word_embed)

            preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
            model = keras.Model(inputs=[word_ids], outputs=[preds])

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(optimizer=gradient_descent.GradientDescentOptimizer(
                learning_rate=0.1),
                          loss='sparse_categorical_crossentropy',
                          metrics=['sparse_categorical_accuracy'])
        return model

    @combinations.generate(test_combinations_for_stateful_embedding_model())
    def test_stateful_lstm_model_correctness(self, distribution, use_numpy,
                                             use_validation_data):
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  is_stateful_model=True)

    @combinations.generate(
        keras_correctness_test_base.test_combinations_with_tpu_strategies())
    def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
            self, distribution, use_numpy, use_validation_data):
        # Stateful models must run on a single core, so multi-core TPU
        # strategies are expected to raise. `assertRaisesRegex` replaces the
        # deprecated `assertRaisesRegexp` alias (removed in Python 3.12).
        with self.assertRaisesRegex(
                ValueError, 'Single core must be used for computation '
                'on stateful models. Consider adding '
                '`device_assignment` parameter to '
                'TPUStrategy'):
            self.run_correctness_test(distribution,
                                      use_numpy,
                                      use_validation_data,
                                      is_stateful_model=True)
# Exemplo n.º 2 — 0  (scraper artifact; commented out so the file parses as Python)
class TestDistributionStrategyDnnCorrectness(
        keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
    """Correctness tests for a small non-linear DNN regression model."""

    def get_model(self,
                  initial_weights=None,
                  distribution=None,
                  input_shapes=None):
        """Builds and compiles the DNN, inside `distribution`'s scope if given."""
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            # A few non-linear layers keep the regression problem non-trivial.
            layer_stack = [
                keras.layers.Dense(10, activation='relu', input_shape=(1, )),
                keras.layers.Dense(
                    10,
                    activation='relu',
                    kernel_regularizer=keras.regularizers.l2(1e-4)),
                keras.layers.Dense(10, activation='relu'),
                keras.layers.Dense(1),
            ]
            model = keras.Sequential()
            for layer in layer_stack:
                model.add(layer)

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(loss=keras.losses.mean_squared_error,
                          optimizer=gradient_descent_keras.SGD(0.05),
                          metrics=['mse'])
            return model

    def get_data(self):
        """Returns (x_train, y_train, x_predict) for the target y = 3x."""
        features = np.random.rand(9984, 1).astype('float32')
        targets = 3 * features
        predict_inputs = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
        return features, targets, predict_inputs

    def get_data_with_partial_last_batch(self):
        """Data whose size (10000) is not a multiple of the global batch size."""
        train_x = np.random.rand(10000, 1).astype('float32')
        train_y = 3 * train_x
        eval_x = np.random.rand(10000, 1).astype('float32')
        eval_y = 3 * eval_x
        predict_inputs = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
        return train_x, train_y, eval_x, eval_y, predict_inputs

    def get_data_with_partial_last_batch_eval(self):
        """Full train batches, but an eval set (10000) with a partial last batch."""
        train_x = np.random.rand(9984, 1).astype('float32')
        train_y = 3 * train_x
        eval_x = np.random.rand(10000, 1).astype('float32')
        eval_y = 3 * eval_x
        predict_inputs = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
        return train_x, train_y, eval_x, eval_y, predict_inputs

    @combinations.generate(keras_correctness_test_base.
                           all_strategy_and_input_config_combinations())
    def test_dnn_correctness(self, distribution, use_numpy,
                             use_validation_data):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @combinations.generate(
        keras_correctness_test_base.test_combinations_with_tpu_strategies())
    def test_dnn_correctness_with_partial_last_batch_eval(
            self, distribution, use_numpy, use_validation_data):
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  partial_last_batch='eval')

    @combinations.generate(
        keras_correctness_test_base.
        strategy_minus_tpu_and_input_config_combinations_eager())
    def test_dnn_correctness_with_partial_last_batch(self, distribution,
                                                     use_numpy,
                                                     use_validation_data):
        # Partial batches require the get-next-as-optional dataset path.
        distribution.extended.experimental_enable_get_next_as_optional = True
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  partial_last_batch='train_and_eval',
                                  training_epochs=1)

    @combinations.generate(all_strategy_combinations_with_graph_mode())
    def test_dnn_with_dynamic_learning_rate(self, distribution):
        self.run_dynamic_lr_test(distribution)
# Exemplo n.º 3 — 0  (scraper artifact; commented out so the file parses as Python)
class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
        TestDistributionStrategyDnnCorrectness):
    """DNN correctness tests using a subclassed (non-functional) Keras model."""

    def get_model(self,
                  initial_weights=None,
                  distribution=None,
                  input_shapes=None):
        """Builds and compiles a `SubclassedModel` inside `distribution`'s scope."""
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            model = SubclassedModel(initial_weights, input_shapes)

            model.compile(loss=keras.losses.mean_squared_error,
                          optimizer=gradient_descent_keras.SGD(0.05),
                          metrics=['mse'])
            return model

    @combinations.generate(keras_correctness_test_base.
                           all_strategy_and_input_config_combinations())
    def test_dnn_correctness(self, distribution, use_numpy,
                             use_validation_data):
        # Subclassed models are supported eagerly or with the default strategy;
        # other configurations are expected to raise specific ValueErrors.
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`.
        if context.executing_eagerly() or is_default_strategy(distribution):
            self.run_correctness_test(distribution, use_numpy,
                                      use_validation_data)
        elif K.is_tpu_strategy(distribution):
            # Reaching this branch already implies graph mode (the eager case
            # is handled above), so no extra executing_eagerly() check needed.
            with self.assertRaisesRegex(
                    ValueError,
                    'Expected `model` argument to be a functional `Model` instance, '
                    'but got a subclass model instead.'):
                self.run_correctness_test(distribution, use_numpy,
                                          use_validation_data)
        else:
            with self.assertRaisesRegex(
                    ValueError,
                    'We currently do not support distribution strategy with a '
                    '`Sequential` model that is created without `input_shape`/'
                    '`input_dim` set in its first layer or a subclassed model.'
            ):
                self.run_correctness_test(distribution, use_numpy,
                                          use_validation_data)

    @combinations.generate(all_strategy_combinations_with_graph_mode())
    def test_dnn_with_dynamic_learning_rate(self, distribution):
        if ((context.executing_eagerly()
             and not K.is_tpu_strategy(distribution))
                or is_default_strategy(distribution)):
            self.run_dynamic_lr_test(distribution)
        elif K.is_tpu_strategy(distribution):
            with self.assertRaisesRegex(
                    ValueError,
                    'Expected `model` argument to be a functional `Model` instance, '
                    'but got a subclass model instead.'):
                self.run_dynamic_lr_test(distribution)
        else:
            with self.assertRaisesRegex(
                    ValueError,
                    'We currently do not support distribution strategy with a '
                    '`Sequential` model that is created without `input_shape`/'
                    '`input_dim` set in its first layer or a subclassed model.'
            ):
                self.run_dynamic_lr_test(distribution)

    @combinations.generate(
        keras_correctness_test_base.test_combinations_with_tpu_strategies())
    def test_dnn_correctness_with_partial_last_batch_eval(
            self, distribution, use_numpy, use_validation_data):
        # TPU strategies reject subclassed models outright.
        with self.assertRaisesRegex(
                ValueError,
                'Expected `model` argument to be a functional `Model` instance, '
                'but got a subclass model instead.'):
            self.run_correctness_test(distribution,
                                      use_numpy,
                                      use_validation_data,
                                      partial_last_batch='eval')
class DistributionStrategyStatefulLstmModelCorrectnessTest(
        keras_correctness_test_base.
        TestDistributionStrategyEmbeddingModelCorrectnessBase):
    """Correctness tests for a stateful LSTM model under distribution strategies."""

    def get_model(self,
                  max_words=10,
                  initial_weights=None,
                  distribution=None,
                  experimental_run_tf_function=None,
                  input_shapes=None):
        """Builds and compiles a stateful-LSTM sequence classifier.

        Args:
          max_words: Length of the integer word-id input sequence.
          initial_weights: Optional weights to load so that runs under
            different strategies start from identical parameters.
          distribution: Optional distribution strategy; the model is built
            inside its scope when provided.
          experimental_run_tf_function: Accepted for signature compatibility
            with the base class; unused when building the model.
          input_shapes: Unused; accepted for signature compatibility.

        Returns:
          A compiled `keras.Model`.
        """
        del input_shapes
        # Stateful RNNs require a fixed batch size on the Input layer.
        batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE

        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            word_ids = keras.layers.Input(shape=(max_words, ),
                                          batch_size=batch_size,
                                          dtype=np.int32,
                                          name='words')
            word_embed = keras.layers.Embedding(input_dim=20,
                                                output_dim=10)(word_ids)
            lstm_embed = keras.layers.LSTM(units=4,
                                           return_sequences=False,
                                           stateful=True)(word_embed)

            preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
            model = keras.Model(inputs=[word_ids], outputs=[preds])

            if initial_weights:
                model.set_weights(initial_weights)

            optimizer_fn = gradient_descent_keras.SGD

            model.compile(optimizer=optimizer_fn(learning_rate=0.1),
                          loss='sparse_categorical_crossentropy',
                          metrics=['sparse_categorical_accuracy'])
        return model

    # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
    # doesn't work and enable for DistributionStrategy more generally.
    @combinations.generate(test_combinations_for_stateful_embedding_model())
    def disabled_test_stateful_lstm_model_correctness(
            self, distribution, use_numpy, use_validation_data,
            experimental_run_tf_function):
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            is_stateful_model=True,
            experimental_run_tf_function=experimental_run_tf_function)

    @combinations.generate(
        combinations.times(
            keras_correctness_test_base.test_combinations_with_tpu_strategies(
            ),
            combinations.combine(experimental_run_tf_function=[True, False])))
    def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
            self, distribution, use_numpy, use_validation_data,
            experimental_run_tf_function):
        # Stateful models must run on a single core, so multi-core TPU
        # strategies are expected to raise. `assertRaisesRegex` replaces the
        # deprecated `assertRaisesRegexp` alias (removed in Python 3.12).
        with self.assertRaisesRegex(
                ValueError,
                'Single core must be used for computation on stateful models. Consider '
                'adding `device_assignment` parameter to TPUStrategy'):
            self.run_correctness_test(
                distribution,
                use_numpy,
                use_validation_data,
                is_stateful_model=True,
                experimental_run_tf_function=experimental_run_tf_function)
# Exemplo n.º 5 — 0  (scraper artifact; commented out so the file parses as Python)
class DistributionStrategyCnnCorrectnessTest(
        keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
    """Correctness tests for a small CNN classifier on synthetic image data."""

    def get_model(self,
                  initial_weights=None,
                  distribution=None,
                  input_shapes=None):
        """Builds and compiles the CNN, inside `distribution`'s scope if given."""
        del input_shapes
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            image = keras.layers.Input(shape=(28, 28, 3), name='image')
            net = keras.layers.Conv2D(
                name='conv1',
                filters=16,
                kernel_size=(3, 3),
                strides=(4, 4),
                kernel_regularizer=keras.regularizers.l2(1e-4))(image)
            # Optionally insert (sync) batch norm, driven by the test fixture.
            if self.with_batch_norm == 'regular':
                net = keras.layers.BatchNormalization(name='bn1')(net)
            elif self.with_batch_norm == 'sync':
                net = keras.layers.SyncBatchNormalization(name='bn1')(net)
            net = keras.layers.MaxPooling2D(pool_size=(2, 2))(net)
            flat = keras.layers.Flatten()(net)
            probs = keras.layers.Dense(10, activation='softmax',
                                       name='pred')(flat)
            model = keras.Model(inputs=[image], outputs=[probs])

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(optimizer=gradient_descent.SGD(learning_rate=0.1),
                          loss='sparse_categorical_crossentropy',
                          metrics=['sparse_categorical_accuracy'])

        return model

    def _get_data(self, count, shape=(28, 28, 3), num_classes=10):
        """Synthesizes `count` samples clustered around random class centers."""
        centers = np.random.randn(num_classes, *shape)

        features, labels = [], []
        for _ in range(count):
            cls = np.random.randint(0, num_classes, size=1)[0]
            noise = np.random.normal(loc=0, scale=0.1,
                                     size=np.prod(shape)).reshape(shape)
            labels.append(cls)
            features.append(centers[cls] + noise)

        x = np.asarray(features, dtype=np.float32)
        y = np.asarray(labels, dtype=np.float32).reshape((count, 1))
        return x, y

    def get_data(self):
        """Returns (x_train, y_train, x_predict); predicts on the train set."""
        sample_count = (keras_correctness_test_base._GLOBAL_BATCH_SIZE *
                        keras_correctness_test_base._EVAL_STEPS)
        x_train, y_train = self._get_data(count=sample_count)
        return x_train, y_train, x_train

    def get_data_with_partial_last_batch_eval(self):
        """Eval set of 1000 samples leaves a partial final batch."""
        train_x, train_y = self._get_data(count=1280)
        eval_x, eval_y = self._get_data(count=1000)
        return train_x, train_y, eval_x, eval_y, eval_x

    @ds_combinations.generate(keras_correctness_test_base.
                              all_strategy_and_input_config_combinations())
    def test_cnn_correctness(self, distribution, use_numpy,
                             use_validation_data):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @ds_combinations.generate(keras_correctness_test_base.
                              all_strategy_and_input_config_combinations())
    def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
                                             use_validation_data):
        self.skipTest('Flakily times out, b/134670856')
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  with_batch_norm='regular')

    @ds_combinations.generate(keras_correctness_test_base.
                              all_strategy_and_input_config_combinations())
    def test_cnn_with_sync_batch_norm_correctness(self, distribution,
                                                  use_numpy,
                                                  use_validation_data):
        if not context.executing_eagerly():
            self.skipTest('SyncBatchNorm is not enabled in graph mode.')

        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  with_batch_norm='sync')

    @ds_combinations.generate(
        keras_correctness_test_base.test_combinations_with_tpu_strategies() +
        keras_correctness_test_base.
        strategy_minus_tpu_and_input_config_combinations_eager())
    def test_cnn_correctness_with_partial_last_batch_eval(
            self, distribution, use_numpy, use_validation_data):
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  partial_last_batch=True,
                                  training_epochs=1)

    @ds_combinations.generate(
        keras_correctness_test_base.test_combinations_with_tpu_strategies() +
        keras_correctness_test_base.
        strategy_minus_tpu_and_input_config_combinations_eager())
    def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
            self, distribution, use_numpy, use_validation_data):
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  with_batch_norm='regular',
                                  partial_last_batch=True)