def run_eval_metrics_correctness_test(self, distribution,
                                      experimental_run_tf_function):
        """Check that `evaluate()` reports exact metric values.

        The model is evaluated twice on an all-ones (100, 4) float32
        feature matrix, batched by 4 and limited to 10 steps per pass:

          * against all-ones labels, both `outs[1]` and `outs[2]` must be
            exactly 1.0;
          * against all-zeros labels, both must be exactly 0.0.

        Per the original intent, `outs[1]` and `outs[2]` are the stateful
        and stateless metrics under test. Which index is which depends on
        `get_model`.

        Args:
          distribution: distribution strategy, forwarded to the
            configuration-skip check and to `get_model`.
          experimental_run_tf_function: flag forwarded to the
            configuration-skip check and to `get_model`.
        """
        with self.cached_session():
            self.set_up_test_config()
            self.skip_unsupported_test_configuration(
                distribution, experimental_run_tf_function)

            model = self.get_model(experimental_run_tf_function,
                                   distribution=distribution)

            features = np.ones((100, 4)).astype('float32')
            # Each case pairs a label factory with the value both metrics must
            # report. Labels are built lazily so that creation order matches
            # the evaluation order.
            for make_labels, expected in ((np.ones, 1.), (np.zeros, 0.)):
                labels = make_labels((100, 1)).astype('float32')
                pairs = dataset_ops.Dataset.from_tensor_slices(
                    (features, labels))
                batches = keras_correctness_test_base.batch_wrapper(
                    pairs.repeat(), 4)
                outs = model.evaluate(batches, steps=10)
                self.assertEqual(outs[1], expected)
                self.assertEqual(outs[2], expected)
  def run_eval_metrics_correctness_test(self, distribution, cloning):
    """Check that `evaluate()` reports exact metric values.

    Uses an all-ones (100, 4) float32 input and runs 10 evaluation steps
    over batches of 4. With all-ones labels, `outs[1]` and `outs[2]` must
    both equal 1.0. With all-zeros labels, both must equal 0.0. Per the
    original intent, these two entries are the stateful and stateless
    metrics under test. Their order depends on `get_model`.

    Args:
      distribution: distribution strategy, forwarded to the
        configuration-skip check and to `get_model`.
      cloning: forwarded as the first positional argument of `get_model`.
    """

    def evaluate_on(net, inputs, targets):
      # Repeat the (inputs, targets) pairs indefinitely, batch them by 4,
      # and evaluate for a fixed 10 steps.
      ds = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
      ds = keras_correctness_test_base.batch_wrapper(ds, 4)
      return net.evaluate(ds, steps=10)

    with self.cached_session():
      self.set_up_test_config()
      self.skip_unsupported_test_configuration(distribution)
      model = self.get_model(cloning, distribution=distribution)

      inputs = np.ones((100, 4)).astype('float32')

      # All-ones labels: both metrics must report exactly 1.0.
      outs = evaluate_on(model, inputs, np.ones((100, 1)).astype('float32'))
      self.assertEqual(outs[1], 1.)
      self.assertEqual(outs[2], 1.)

      # All-zeros labels: both metrics must report exactly 0.0.
      outs = evaluate_on(model, inputs, np.zeros((100, 1)).astype('float32'))
      self.assertEqual(outs[1], 0.)
      self.assertEqual(outs[2], 0.)
# Example #3
# 0
    def run_metric_correctness_test(self, distribution):
        """Train briefly and check that binary accuracy is exactly 1.0.

        Fits the model from `get_model` on the training split from
        `get_data` for 2 epochs of 10 steps each. It then asserts that the
        recorded `'binary_accuracy'` history is exactly `[1.0, 1.0]`.

        Args:
          distribution: distribution strategy, passed to `get_model` and
            used to derive the batch size.
        """
        with self.cached_session():
            self.set_up_test_config()

            # The third element returned by get_data is not needed here.
            x_train, y_train, _ = self.get_data()
            model = self.get_model(distribution=distribution)

            # Start from 64 and let get_batch_size adapt it for
            # `distribution`. The exact adjustment lives in that helper.
            step_batch_size = keras_correctness_test_base.get_batch_size(
                64, distribution)
            train_dataset = keras_correctness_test_base.batch_wrapper(
                dataset_ops.Dataset.from_tensor_slices((x_train, y_train)),
                step_batch_size)

            history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10)
            per_epoch_accuracy = history.history['binary_accuracy']
            self.assertEqual(per_epoch_accuracy, [1.0, 1.0])