예제 #1
0
    def test_save_load_trackable(self, distribution, optimizer):
        """Round-trips model weights through a TF-format checkpoint.

        Under `distribution`, a model is trained for one step and its weights
        are saved. The weights are loaded into a freshly compiled second model,
        which is then used for `predict` and `fit`.

        Args:
            distribution: Distribution strategy whose scope both models are
                built and compiled in.
            optimizer: Zero-argument callable; each call's result is passed to
                `compile`, so each model gets its own optimizer object.
        """
        import os  # Only needed to build the checkpoint path below.

        # TODO(b/123533246): Enable the test for TPU once bug is fixed
        tpu_strategy_types = (
            tf.distribute.experimental.TPUStrategy,
            tf.compat.v1.distribute.experimental.TPUStrategy,
        )
        if (isinstance(distribution, tpu_strategy_types)
                and distribution.extended.steps_per_run > 1):
            self.skipTest(
                "MultiStep TPU Strategy deadlocks with optimizer restore.")

        # `tempfile.mktemp()` only returns a name, so it is race-prone and
        # deprecated, and the checkpoint files it pointed at were never
        # removed. TemporaryDirectory creates the directory atomically and
        # deletes it on exit. It is the OUTERMOST context so the checkpoint
        # stays on disk until `model_2.fit` has finished, in case part of the
        # restore is deferred until variables are created (presumably the
        # optimizer state during `fit` — TODO confirm).
        with tempfile.TemporaryDirectory() as tmp_dir, self.cached_session():
            dataset = keras_test_lib.get_dataset(distribution)
            with distribution.scope():
                model = keras_test_lib.get_model()
                model.compile(optimizer(), "mse")
                model.fit(dataset, epochs=1, steps_per_epoch=1)

                # No file extension, matching the original mktemp() name
                # (i.e. the non-HDF5 checkpoint path).
                weights_file = os.path.join(tmp_dir, "weights")
                model.save_weights(weights_file)

                model_2 = keras_test_lib.get_model()
                model_2.compile(optimizer(), "mse")
                model_2.load_weights(weights_file)
                model_2.predict(
                    keras_test_lib.get_predict_dataset(distribution), steps=2)
                model_2.fit(dataset, epochs=1, steps_per_epoch=1)
예제 #2
0
    def test_save_load_h5(self, distribution, optimizer):
        """Round-trips model weights through an HDF5 (`.h5`) weights file.

        Under `distribution`, a model is trained for one step and its weights
        are saved to a `.h5` file. The weights are loaded into a freshly
        compiled second model, which is then used for `predict` and `fit`.

        Args:
            distribution: Distribution strategy whose scope both models are
                built and compiled in.
            optimizer: Zero-argument callable; each call's result is passed to
                `compile`, so each model gets its own optimizer object.
        """
        import os  # Only needed to build the weights-file path below.

        # `tempfile.mktemp()` only returns a name, so it is race-prone and
        # deprecated, and the weights file was never removed. TemporaryDirectory
        # creates the directory atomically and deletes it (and the file) on
        # exit. It is the outermost context so the file outlives every use.
        with tempfile.TemporaryDirectory() as tmp_dir, self.cached_session():
            dataset = keras_test_lib.get_dataset(distribution)
            with distribution.scope():
                model = keras_test_lib.get_model()
                model.compile(optimizer(), 'mse')
                model.fit(dataset, epochs=1, steps_per_epoch=1)

                # The '.h5' suffix is kept from the original mktemp('.h5') call.
                weights_file = os.path.join(tmp_dir, 'weights.h5')
                model.save_weights(weights_file)

                model_2 = keras_test_lib.get_model()
                model_2.compile(optimizer(), 'mse')
                model_2.load_weights(weights_file)
                model_2.predict(
                    keras_test_lib.get_predict_dataset(distribution), steps=2)
                model_2.fit(dataset, epochs=1, steps_per_epoch=1)
예제 #3
0
    def test_callbacks_in_predict(self, distribution):
        """Checks how many times each predict-phase callback hook fires.

        The model is built and compiled under `distribution` and then
        predicts for 5 steps with a `Counter` callback. The recorded
        `method_counts` must show 5 calls to each per-batch hook and 1 call
        each to `on_predict_begin` / `on_predict_end`.
        """
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

        counter = Counter()

        # `get_predict_dataset` is called with the distribution strategy, as
        # in the save/load tests. Previously the training dataset from
        # `get_dataset(distribution)` was passed here by mistake.
        model.predict(keras_test_lib.get_predict_dataset(distribution),
                      steps=5,
                      callbacks=[counter])

        self.assertDictEqual(
            counter.method_counts, {
                'on_predict_batch_begin': 5,
                'on_predict_batch_end': 5,
                'on_predict_begin': 1,
                'on_predict_end': 1
            })
예제 #4
0
    def test_callbacks_in_predict(self, distribution):
        """Checks how many times each predict-phase callback hook fires.

        The model is built and compiled under `distribution` and then
        predicts for 5 steps with a `Counter` callback. The recorded
        `method_counts` must show 5 calls to each per-batch hook and 1 call
        each to `on_predict_begin` / `on_predict_end`.
        """
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

        counter = Counter()

        # `get_predict_dataset` is called with the distribution strategy, as
        # in the save/load tests. Previously the training dataset from
        # `get_dataset(distribution)` was passed here by mistake.
        model.predict(
            keras_test_lib.get_predict_dataset(distribution),
            steps=5,
            callbacks=[counter],
        )

        self.assertDictEqual(
            counter.method_counts,
            {
                "on_predict_batch_begin": 5,
                "on_predict_batch_end": 5,
                "on_predict_begin": 1,
                "on_predict_end": 1,
            },
        )