def testSaveAndLoadSavedModelExport(
      self, model_builder, uses_learning_phase, optimizer, train_before_export):
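    """Tests saving a Keras model as a SavedModel and loading it back.

    Exports the model, then reloads the predict, eval, and train graphs and
    checks that predictions, loss, metric values, and the global step match
    the in-memory reference model.
    """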
    saved_model_path = self._save_model_dir()
    with self.session(graph=ops.Graph()):
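      # Seed NumPy so the single-example batch and the reference values
      # computed below are deterministic across runs.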
      np.random.seed(130)
      input_arr = np.random.random((1, 3))
      target_arr = np.random.random((1, 3))

      model = model_builder(uses_learning_phase)
      if optimizer is not None:
        model.compile(
            loss='mse',
            optimizer=optimizer,
            metrics=['mae'])
        if train_before_export:
          model.train_on_batch(input_arr, target_arr)

        ref_loss, ref_mae = model.evaluate(input_arr, target_arr)

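      # Reference predictions from the in-memory model; the loaded graphs
      # below must reproduce these.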
      ref_predict = model.predict(input_arr)

      # Export SavedModel
      output_path = keras_saved_model.save_keras_model(model, saved_model_path)

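    # The exported signatures name tensors after the Keras input/output
    # layers; target tensors get an '_target' suffix.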
    input_name = model.input_names[0]
    output_name = model.output_names[0]
    target_name = output_name + '_target'

    # Load predict graph, and test predictions
    with session.Session(graph=ops.Graph()) as sess:
      inputs, outputs, _ = load_model(sess, output_path,
                                      model_fn_lib.ModeKeys.PREDICT)

      predictions = sess.run(outputs[output_name],
                             {inputs[input_name]: input_arr})
      self.assertAllClose(ref_predict, predictions, atol=1e-05)

    if optimizer is not None:
      # Load eval graph, and test predictions, loss and metric values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs, _ = load_model(sess, output_path,
                                        model_fn_lib.ModeKeys.EVAL)

        # First obtain the loss and predictions, and run the metric update op by
        # feeding in the inputs and targets.
        loss, predictions, _ = sess.run(
            (outputs['loss'], outputs['predictions/' + output_name],
             outputs['metrics/mean_absolute_error/update_op']), {
                 inputs[input_name]: input_arr,
                 inputs[target_name]: target_arr
             })

        # Fetch the metric value only after the update op has run, so that it
        # reflects the updated state.
        metric_value = sess.run(outputs['metrics/mean_absolute_error/value'])

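        # The global step should equal the number of train steps taken before
        # export (0 or 1).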
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertAllClose(ref_loss, loss, atol=1e-05)
        self.assertAllClose(ref_mae, metric_value, atol=1e-05)
        self.assertAllClose(ref_predict, predictions, atol=1e-05)

      # Load the train graph; check for the train op and verify prediction
      # values.
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs, meta_graph_def = load_model(
            sess, output_path, model_fn_lib.ModeKeys.TRAIN)
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertIn('loss', outputs)
        self.assertIn('metrics/mean_absolute_error/update_op', outputs)
        self.assertIn('metrics/mean_absolute_error/value', outputs)
        self.assertIn('predictions/' + output_name, outputs)

        # Train for a step
        train_op = loader_impl.get_train_op(meta_graph_def)
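        # Fetch the train op from the meta graph and run it once; this should
        # update the weights and increment the global step.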
        train_outputs, _ = sess.run(
            [outputs, train_op], {inputs[input_name]: input_arr,
                                  inputs[target_name]: target_arr})
        self.assertEqual(int(train_before_export) + 1,
                         sess.run(training_module.get_global_step()))

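        # The learning-phase models used here are expected to produce all-zero
        # predictions in train mode, which distinguishes the train graph from
        # the inference graph.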
        if uses_learning_phase:
          self.assertAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)
        else:
          self.assertNotAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)