예제 #1
0
class SavedModelKerasModelTest(test_base.TestSavedModelBase):
  """Keras save/restore tests driven by the core SavedModel API.

  Models are written with `saved_model.save` and read back through
  `test_base.load_and_run_with_saved_model_api`. The scenarios themselves
  (the `run_test_*` methods) come from `test_base.TestSavedModelBase`.
  """

  def setUp(self):
    # `_root_dir` is assigned before the base setUp runs, presumably because
    # the base class reads it there (NOTE(review): confirm in test_base).
    self._root_dir = 'saved_model_save_load'
    super(SavedModelKerasModelTest, self).setUp()

  def _save_model(self, model, saved_dir):
    """Writes `model` to `saved_dir` with `saved_model.save`."""
    saved_model.save(model, saved_dir)

  def _load_and_run_model(self,
                          distribution,
                          saved_dir,
                          predict_dataset,
                          output_name='output_1'):
    """Reloads from `saved_dir` and predicts via the SavedModel API.

    All arguments are forwarded positionally to
    `test_base.load_and_run_with_saved_model_api`, whose result is returned
    unchanged.
    """
    outputs = test_base.load_and_run_with_saved_model_api(
        distribution, saved_dir, predict_dataset, output_name)
    return outputs

  @combinations.generate(test_base.simple_models_with_strategies())
  def test_save_no_strategy_restore_strategy(self, model_and_input,
                                             distribution):
    """Runs the base-class no-strategy-save / strategy-restore scenario."""
    self.run_test_save_no_strategy_restore_strategy(model_and_input,
                                                    distribution)

  @combinations.generate(
      combinations.times(
          test_base.simple_models_with_strategies(),
          combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_no_strategy(self, model_and_input,
                                             distribution, save_in_scope):
    """Runs the base-class strategy-save / no-strategy-restore scenario."""
    self.run_test_save_strategy_restore_no_strategy(model_and_input,
                                                    distribution,
                                                    save_in_scope)

  @combinations.generate(
      combinations.times(
          test_base.simple_models_with_strategy_pairs(),
          combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_strategy(self, model_and_input,
                                          distribution_for_saving,
                                          distribution_for_restoring,
                                          save_in_scope):
    """Runs the base-class scenario with separate save/restore strategies."""
    self.run_test_save_strategy_restore_strategy(
        model_and_input, distribution_for_saving, distribution_for_restoring,
        save_in_scope)

  @combinations.generate(
      combinations.times(
          test_base.simple_models_with_strategies(),
          combinations.combine(save_in_scope=[True, False])))
  def test_no_variable_device_placement(self, model_and_input, distribution,
                                        save_in_scope):
    """Checks ReadVariableOp nodes in the reloaded signature have no device.

    Only nodes inside the function library of the default signature's graph
    are inspected; top-level `graph_def.node` entries are not examined.
    """
    saved_dir = self.run_test_save_strategy(model_and_input, distribution,
                                            save_in_scope)
    loaded = saved_model.load(saved_dir)
    signature = loaded.signatures[test_base._DEFAULT_FUNCTION_KEY]
    function_library = signature.graph.as_graph_def().library
    read_variable_nodes = (
        node for function_def in function_library.function
        for node in function_def.node_def if node.op == 'ReadVariableOp')
    for node in read_variable_nodes:
      self.assertEmpty(node.device)
class KerasSaveLoadTest(test_base.TestSavedModelBase):
  """Keras save/restore tests using the Keras-level save and load APIs.

  Models are written with `Model.save(..., save_format='tf')` and restored
  with `save.load_model`. The scenarios themselves (the `run_test_*`
  methods) come from `test_base.TestSavedModelBase`.
  """

  def setUp(self):
    # `_root_dir` is assigned before the base setUp runs, presumably because
    # the base class reads it there (NOTE(review): confirm in test_base).
    self._root_dir = 'keras_save_load'
    super(KerasSaveLoadTest, self).setUp()

  def _save_model(self, model, saved_dir):
    """Writes `model` to `saved_dir` in the 'tf' (SavedModel) save format."""
    model.save(saved_dir, save_format='tf')

  def _load_and_run_model(self,
                          distribution,
                          saved_dir,
                          predict_dataset,
                          output_name='output_1'):
    """Reloads the Keras model and returns its predictions.

    Prediction runs on `predict_dataset` for `test_base.PREDICT_STEPS` steps.
    `distribution` and `output_name` exist only to match the hook signature
    and are not used by this implementation.
    """
    del distribution, output_name  # Unused by this hook.
    keras_model = save.load_model(saved_dir)
    return keras_model.predict(predict_dataset, steps=test_base.PREDICT_STEPS)

  @combinations.generate(test_base.simple_models_with_strategies())
  def test_save_no_strategy_restore_strategy(self, model_and_input,
                                             distribution):
    """Runs the base-class no-strategy-save / strategy-restore scenario."""
    self.run_test_save_no_strategy_restore_strategy(model_and_input,
                                                    distribution)

  @combinations.generate(
      combinations.times(
          test_base.simple_models_with_strategies(),
          combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_no_strategy(self, model_and_input,
                                             distribution, save_in_scope):
    """Runs the base-class strategy-save / no-strategy-restore scenario."""
    self.run_test_save_strategy_restore_no_strategy(model_and_input,
                                                    distribution,
                                                    save_in_scope)

  @combinations.generate(
      combinations.times(
          test_base.simple_models_with_strategy_pairs(),
          combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_strategy(self, model_and_input,
                                          distribution_for_saving,
                                          distribution_for_restoring,
                                          save_in_scope):
    """Runs the base-class scenario with separate save/restore strategies."""
    self.run_test_save_strategy_restore_strategy(
        model_and_input, distribution_for_saving, distribution_for_restoring,
        save_in_scope)
예제 #3
0
class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
    """Save/restore tests: Keras `save_model` out, SavedModel API back in.

    Models are written with `save.save_model(..., save_format='tf')` and read
    back through `test_base.load_and_run_with_saved_model_api`. The scenarios
    themselves (the `run_test_*` methods) come from
    `test_base.TestSavedModelBase`.
    """

    def setUp(self):
        # `_root_dir` is assigned before the base setUp runs, presumably
        # because the base class reads it there (NOTE(review): confirm).
        self._root_dir = 'saved_model_save_load'
        super(SavedModelSaveAndLoadTest, self).setUp()

    def _save_model(self, model, saved_dir):
        """Writes `model` to `saved_dir` via `save.save_model` ('tf' format)."""
        save.save_model(model, saved_dir, save_format='tf')

    def _load_and_run_model(self,
                            distribution,
                            saved_dir,
                            predict_dataset,
                            output_name='output_1'):
        """Reloads from `saved_dir` and predicts via the SavedModel API.

        All arguments are forwarded positionally to
        `test_base.load_and_run_with_saved_model_api`, whose result is
        returned unchanged.
        """
        outputs = test_base.load_and_run_with_saved_model_api(
            distribution, saved_dir, predict_dataset, output_name)
        return outputs

    @ds_combinations.generate(test_base.simple_models_with_strategies())
    def test_save_no_strategy_restore_strategy(self, model_and_input,
                                               distribution):
        """Runs the base-class no-strategy-save / strategy-restore scenario."""
        self.run_test_save_no_strategy_restore_strategy(model_and_input,
                                                        distribution)

    @ds_combinations.generate(
        combinations.times(
            test_base.simple_models_with_strategies(),
            combinations.combine(save_in_scope=[True, False])))
    def test_save_strategy_restore_no_strategy(self, model_and_input,
                                               distribution, save_in_scope):
        """Runs the base-class strategy-save / no-strategy-restore scenario."""
        self.run_test_save_strategy_restore_no_strategy(model_and_input,
                                                        distribution,
                                                        save_in_scope)

    @ds_combinations.generate(
        combinations.times(
            test_base.simple_models_with_strategy_pairs(),
            combinations.combine(save_in_scope=[True, False])))
    def test_save_strategy_restore_strategy(self, model_and_input,
                                            distribution_for_saving,
                                            distribution_for_restoring,
                                            save_in_scope):
        """Runs the base-class scenario with separate save/restore strategies."""
        self.run_test_save_strategy_restore_strategy(
            model_and_input,
            distribution_for_saving,
            distribution_for_restoring,
            save_in_scope)