Example #1
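These are class-level excerpts from TensorFlow's Keras distribution-strategy tests, so their imports are not shown. Assuming the TensorFlow-internal module layout these tests rely on (exact paths differ between TF releases), the dependencies are roughly:

import numpy as np
from absl.testing import parameterized

# Internal TensorFlow modules; the paths below are assumptions based on the
# TF 1.x/2.x test layout and may need adjusting for other TF versions.
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import test
from tensorflow.python.training import gradient_descent

# `keras_test_lib` is assumed to be the shared Keras distribute test module
# providing all_strategy_combinations, batch_wrapper, get_model, get_dataset
# and get_predict_dataset.
from tensorflow.python.keras.distribute import distribute_strategy_test as keras_test_lib
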
class TestDistributionStrategyWithNormalizationLayer(test.TestCase,
                                                     parameterized.TestCase):
    @combinations.generate(
        combinations.times(keras_test_lib.all_strategy_combinations(),
                           combinations.combine(fused=[True, False])))
    def test_batchnorm_correctness(self, distribution, fused):
        with self.cached_session():
            with distribution.scope():
                model = keras.models.Sequential()
                norm = keras.layers.BatchNormalization(input_shape=(10,),
                                                       momentum=0.8,
                                                       fused=fused)
                model.add(norm)
                model.compile(
                    loss='mse',
                    optimizer=gradient_descent.GradientDescentOptimizer(0.01))

            # Inputs centered on 5.0 with a standard deviation of 10.0.
            x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
            x = x.astype('float32')
            dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
            dataset = dataset.repeat(100)
            dataset = keras_test_lib.batch_wrapper(dataset, 32, distribution)

            predict_dataset = dataset_ops.Dataset.from_tensor_slices(x)
            predict_dataset = predict_dataset.repeat(100)
            predict_dataset = keras_test_lib.batch_wrapper(
                predict_dataset, 32, distribution)

            model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
            out = model.predict(predict_dataset, steps=2)
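            # Undo the learned affine parameters (beta, gamma); what remains
            # should be the normalized activations: ~zero mean, ~unit std.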
            out -= keras.backend.eval(norm.beta)
            out /= keras.backend.eval(norm.gamma)
            np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
            np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
Example #2
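The `Counter` callback used below is not defined in this excerpt. It is a keras.callbacks.Callback that tallies how many times each callback hook fires; a minimal sketch, assuming this hook list and a wrap-and-delegate approach (not necessarily the test suite's exact implementation):

import collections


class Counter(keras.callbacks.Callback):
    """Tallies how many times each callback hook is invoked."""

    _HOOKS = [
        'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end',
        'on_train_batch_begin', 'on_train_batch_end', 'on_train_begin',
        'on_train_end', 'on_test_batch_begin', 'on_test_batch_end',
        'on_test_begin', 'on_test_end', 'on_predict_batch_begin',
        'on_predict_batch_end', 'on_predict_begin', 'on_predict_end'
    ]

    def __init__(self):
        super(Counter, self).__init__()
        self.method_counts = collections.defaultdict(int)
        for name in self._HOOKS:
            # Shadow each hook with a wrapper that records the call and then
            # delegates to the original (no-op or backward-compat) method.
            setattr(self, name, self._wrap(name, getattr(self, name)))

    def _wrap(self, name, method):
        def _call_and_count(*args, **kwargs):
            self.method_counts[name] += 1
            return method(*args, **kwargs)
        return _call_and_count
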
class TestDistributionStrategyWithCallbacks(test.TestCase,
                                            parameterized.TestCase):
    @combinations.generate(keras_test_lib.all_strategy_combinations())
    def test_callbacks_in_fit(self, distribution):
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

        dataset = keras_test_lib.get_dataset(distribution)
        counter = Counter()

        epochs = 2
        steps_per_epoch = 5
        validation_steps = 3

        model.fit(dataset,
                  epochs=epochs,
                  steps_per_epoch=steps_per_epoch,
                  verbose=0,
                  validation_data=dataset,
                  validation_steps=validation_steps,
                  callbacks=[counter])

        if isinstance(distribution, tpu_strategy.TPUStrategy):
            # TPUStrategy may run several steps per host call (extended.steps_per_run),
            # so batch-level callbacks fire ceil(steps_per_epoch / steps_per_run)
            # times per epoch; when steps_per_run == 1 this equals steps_per_epoch.
            steps_per_run = distribution.extended.steps_per_run
            num_batch_call_per_epoch = steps_per_epoch // steps_per_run
            if steps_per_epoch % steps_per_run:
                num_batch_call_per_epoch += 1
        else:
            num_batch_call_per_epoch = steps_per_epoch

        self.assertDictEqual(
            counter.method_counts, {
                'on_batch_begin': epochs * num_batch_call_per_epoch,
                'on_batch_end': epochs * num_batch_call_per_epoch,
                'on_epoch_begin': epochs,
                'on_epoch_end': epochs,
                'on_test_batch_begin': epochs * validation_steps,
                'on_test_batch_end': epochs * validation_steps,
                'on_test_begin': epochs,
                'on_test_end': epochs,
                'on_train_batch_begin': epochs * num_batch_call_per_epoch,
                'on_train_batch_end': epochs * num_batch_call_per_epoch,
                'on_train_begin': 1,
                'on_train_end': 1
            })

    @combinations.generate(keras_test_lib.all_strategy_combinations())
    def test_callbacks_in_eval(self, distribution):
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

        dataset = keras_test_lib.get_dataset(distribution)
        counter = Counter()

        model.evaluate(dataset, steps=5, callbacks=[counter])

        self.assertDictEqual(
            counter.method_counts, {
                'on_test_batch_begin': 5,
                'on_test_batch_end': 5,
                'on_test_begin': 1,
                'on_test_end': 1
            })

    @combinations.generate(keras_test_lib.all_strategy_combinations())
    def test_callbacks_in_predict(self, distribution):
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

        dataset = keras_test_lib.get_dataset(distribution)
        counter = Counter()

        model.predict(keras_test_lib.get_predict_dataset(dataset),
                      steps=5,
                      callbacks=[counter])

        self.assertDictEqual(
            counter.method_counts, {
                'on_predict_batch_begin': 5,
                'on_predict_batch_end': 5,
                'on_predict_begin': 1,
                'on_predict_end': 1
            })