Esempio n. 1
0
  def test_save_load_trackable(self, distribution, optimizer):
    """Saves weights to an extension-less path and reloads them elsewhere.

    Fits a model for one step, writes its weights, loads them into a second
    freshly compiled model, and then runs predict and fit on that second
    model to confirm it is still usable.

    Args:
      distribution: Object whose `scope()` the models are built under and
        which is handed to the `keras_test_lib` dataset helpers.
      optimizer: Zero-argument callable; called once per model to obtain the
        optimizer passed to `compile`.
    """
    import os  # pylint: disable=g-import-not-at-top
    import shutil  # pylint: disable=g-import-not-at-top

    # TODO(b/123533246): Enable the test for TPU once bug is fixed
    if (isinstance(distribution,
                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)) and
        distribution.extended.steps_per_run > 1):
      self.skipTest('MultiStep TPU Strategy deadlocks with optimizer restore.')
    with self.cached_session():
      dataset = keras_test_lib.get_dataset(distribution)
      with distribution.scope():
        model = keras_test_lib.get_model()
        model.compile(
            optimizer(),
            'mse')
        model.fit(dataset, epochs=1, steps_per_epoch=1)

        # tempfile.mktemp() is deprecated and race-prone: it only returns a
        # name that another process may claim first. A private directory
        # gives a safe location, and the cleanup removes whatever files the
        # save produces.
        weights_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, weights_dir, ignore_errors=True)
        weights_file = os.path.join(weights_dir, 'weights')
        model.save_weights(weights_file)

        model_2 = keras_test_lib.get_model()
        model_2.compile(
            optimizer(),
            'mse')
        model_2.load_weights(weights_file)
        model_2.predict(
            keras_test_lib.get_predict_dataset(distribution), steps=2)
        model_2.fit(dataset, epochs=1, steps_per_epoch=1)
Esempio n. 2
0
  def test_callbacks_in_predict(self, distribution,
                                experimental_run_tf_function):
    """Checks how many times `predict` invokes each callback hook.

    Args:
      distribution: Object whose `scope()` the model is built under and which
        is handed to `keras_test_lib.get_predict_dataset`.
      experimental_run_tf_function: Forwarded unchanged to `model.compile`.
    """
    with distribution.scope():
      model = keras_test_lib.get_model()
      model.compile(
          optimizer='sgd',
          loss='mse',
          metrics=['mae'],
          experimental_run_tf_function=experimental_run_tf_function)

    counter = Counter()

    # The helper is given `distribution` itself, matching every other
    # get_predict_dataset call in this file, rather than the output of
    # get_dataset.
    # NOTE(review): confirm against keras_test_lib that the helper expects
    # the distribution object.
    model.predict(
        keras_test_lib.get_predict_dataset(distribution),
        steps=5,
        callbacks=[counter])

    # With steps=5: one begin/end pair for the whole run and one batch
    # begin/end pair per step.
    self.assertDictEqual(
        counter.method_counts, {
            'on_predict_batch_begin': 5,
            'on_predict_batch_end': 5,
            'on_predict_begin': 1,
            'on_predict_end': 1
        })
  def test_save_load_h5(self, distribution, optimizer, cloning):
    """Saves weights to a `.h5` path and reloads them into a second model.

    Fits a model for one step, writes its weights, loads them into a second
    freshly compiled model, and then runs predict and fit on that second
    model to confirm it is still usable.

    Args:
      distribution: Object whose `scope()` the models are built under and
        which is handed to the `keras_test_lib` dataset helpers.
      optimizer: Zero-argument callable; called once per model to obtain the
        optimizer passed to `compile`.
      cloning: Forwarded unchanged to `compile` as the `cloning` keyword.
    """
    import os  # pylint: disable=g-import-not-at-top
    import shutil  # pylint: disable=g-import-not-at-top

    with self.cached_session():
      dataset = keras_test_lib.get_dataset(distribution)
      with distribution.scope():
        model = keras_test_lib.get_model()
        model.compile(optimizer(), 'mse', cloning=cloning)
        model.fit(dataset, epochs=1, steps_per_epoch=1)

        # tempfile.mktemp() is deprecated and race-prone: it only returns a
        # name that another process may claim first. Write into a private
        # directory instead and remove it when the test finishes.
        weights_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, weights_dir, ignore_errors=True)
        weights_file = os.path.join(weights_dir, 'weights.h5')
        model.save_weights(weights_file)

        model_2 = keras_test_lib.get_model()
        model_2.compile(optimizer(), 'mse', cloning=cloning)
        model_2.load_weights(weights_file)
        model_2.predict(
            keras_test_lib.get_predict_dataset(distribution), steps=2)
        model_2.fit(dataset, epochs=1, steps_per_epoch=1)
Esempio n. 4
0
  def test_save_load_h5(self, distribution, optimizer, cloning):
    """Round-trips model weights through a file with the `.h5` suffix.

    One model is trained for a single step and its weights are written out;
    a second, newly compiled model loads them and must then be able to
    predict and to train for another step.

    Args:
      distribution: Object whose `scope()` the models are built under and
        which is handed to the `keras_test_lib` dataset helpers.
      optimizer: Zero-argument callable; called once per model to obtain the
        optimizer passed to `compile`.
      cloning: Forwarded unchanged to `compile` as the `cloning` keyword.
    """
    import os  # pylint: disable=g-import-not-at-top
    import shutil  # pylint: disable=g-import-not-at-top

    with self.cached_session():
      dataset = keras_test_lib.get_dataset(distribution)
      with distribution.scope():
        model = keras_test_lib.get_model()
        model.compile(optimizer(), 'mse', cloning=cloning)
        model.fit(dataset, epochs=1, steps_per_epoch=1)

        # Avoid the deprecated, race-prone tempfile.mktemp(): create a
        # private directory, place the weights file inside it, and delete
        # the directory once the test is over.
        weights_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, weights_dir, ignore_errors=True)
        weights_file = os.path.join(weights_dir, 'weights.h5')
        model.save_weights(weights_file)

        model_2 = keras_test_lib.get_model()
        model_2.compile(optimizer(), 'mse', cloning=cloning)
        model_2.load_weights(weights_file)
        model_2.predict(
            keras_test_lib.get_predict_dataset(distribution), steps=2)
        model_2.fit(dataset, epochs=1, steps_per_epoch=1)
Esempio n. 5
0
  def test_save_load_h5(self, distribution):
    """Saves weights to a `.h5` path and reloads them into a second model.

    Both models are compiled with a fresh `RMSprop(learning_rate=0.01)`
    optimizer and the 'mse' loss. After loading the saved weights, the second
    model runs predict and one more fit step to confirm it is usable.

    Args:
      distribution: Object whose `scope()` the models are built under and
        which is handed to the `keras_test_lib` dataset helpers.
    """
    import os  # pylint: disable=g-import-not-at-top
    import shutil  # pylint: disable=g-import-not-at-top

    with self.cached_session():
      dataset = keras_test_lib.get_dataset(distribution)
      with distribution.scope():
        model = keras_test_lib.get_model()
        model.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse')
        model.fit(dataset, epochs=1, steps_per_epoch=1)

        # tempfile.mktemp() is deprecated and race-prone: it only returns a
        # name that another process may claim first. Write into a private
        # directory instead and remove it when the test finishes.
        weights_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, weights_dir, ignore_errors=True)
        weights_file = os.path.join(weights_dir, 'weights.h5')
        model.save_weights(weights_file)

        model_2 = keras_test_lib.get_model()
        model_2.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse')
        model_2.load_weights(weights_file)
        model_2.predict(
            keras_test_lib.get_predict_dataset(distribution), steps=2)
        model_2.fit(dataset, epochs=1, steps_per_epoch=1)
  def test_callbacks_in_predict(self, distribution):
    """Checks how many times `predict` invokes each callback hook.

    Args:
      distribution: Object whose `scope()` the model is built under and which
        is handed to `keras_test_lib.get_predict_dataset`.
    """
    with distribution.scope():
      model = keras_test_lib.get_model()
      model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

    counter = Counter()

    # The helper is given `distribution` itself, matching every other
    # get_predict_dataset call in this file, rather than the output of
    # get_dataset.
    # NOTE(review): confirm against keras_test_lib that the helper expects
    # the distribution object.
    model.predict(
        keras_test_lib.get_predict_dataset(distribution),
        steps=5,
        callbacks=[counter])

    # With steps=5: one begin/end pair for the whole run and one batch
    # begin/end pair per step.
    self.assertDictEqual(
        counter.method_counts, {
            'on_predict_batch_begin': 5,
            'on_predict_batch_end': 5,
            'on_predict_begin': 1,
            'on_predict_end': 1
        })
Esempio n. 7
0
  def test_save_load_trackable(self, distribution, optimizer, cloning):
    """Saves weights to an extension-less path and reloads them elsewhere.

    Fits a model for one step, writes its weights, loads them into a second
    freshly compiled model, and then runs predict and fit on that second
    model to confirm it is still usable.

    Args:
      distribution: Object whose `scope()` the models are built under and
        which is handed to the `keras_test_lib` dataset helpers.
      optimizer: Zero-argument callable; called once per model to obtain the
        optimizer passed to `compile`.
      cloning: Forwarded unchanged to `compile` as the `cloning` keyword.
    """
    import os  # pylint: disable=g-import-not-at-top
    import shutil  # pylint: disable=g-import-not-at-top

    # TODO(b/123533246): Enable the test for TPU once bug is fixed
    if (isinstance(distribution, (tpu_strategy.TPUStrategy,
                                  tpu_strategy.TPUStrategyV1)) and
        distribution.extended.steps_per_run > 1):
      self.skipTest('MultiStep TPU Strategy deadlocks with optimizer restore.')
    with self.cached_session():
      dataset = keras_test_lib.get_dataset(distribution)
      with distribution.scope():
        model = keras_test_lib.get_model()
        model.compile(optimizer(), 'mse', cloning=cloning)
        model.fit(dataset, epochs=1, steps_per_epoch=1)

        # tempfile.mktemp() is deprecated and race-prone: it only returns a
        # name that another process may claim first. A private directory
        # gives a safe location, and the cleanup removes whatever files the
        # save produces.
        weights_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, weights_dir, ignore_errors=True)
        weights_file = os.path.join(weights_dir, 'weights')
        model.save_weights(weights_file)

        model_2 = keras_test_lib.get_model()
        model_2.compile(optimizer(), 'mse', cloning=cloning)
        model_2.load_weights(weights_file)
        model_2.predict(
            keras_test_lib.get_predict_dataset(distribution), steps=2)
        model_2.fit(dataset, epochs=1, steps_per_epoch=1)