コード例 #1
0
    def run_eval_metrics_correctness_test(self, distribution):
        """Check that evaluate() reports exact metric values on trivial data.

        The model is evaluated twice on the same all-ones feature matrix:
        first against all-ones labels, then against all-zeros labels. Both
        ``outs[1]`` and ``outs[2]`` must equal exactly 1.0 on the first pass
        and exactly 0.0 on the second. Per the original comment, these two
        entries cover a stateful and a stateless metric (which index is
        which depends on ``get_model`` — not visible here).

        Args:
            distribution: strategy object forwarded to ``get_model``.
        """
        with self.cached_session():
            self.set_up_test_config()

            model = self.get_model(distribution=distribution)

            # Features stay fixed across passes; only the labels change.
            features = np.ones((100, 4)).astype('float32')
            # Each pass: (label factory, value both metrics must report).
            cases = ((np.ones, 1.), (np.zeros, 0.))
            for make_labels, expected in cases:
                labels = make_labels((100, 1)).astype('float32')
                ds = tf.data.Dataset.from_tensor_slices((features, labels))
                # Repeat before batching so 10 steps of batch 4 never run dry.
                ds = keras_correctness_test_base.batch_wrapper(ds.repeat(), 4)
                results = model.evaluate(ds, steps=10)
                self.assertEqual(results[1], expected)
                self.assertEqual(results[2], expected)
コード例 #2
0
    def run_metric_correctness_test(self, distribution):
        """Check that fit() reports perfect binary accuracy on every epoch.

        Trains for two epochs of ten steps each on the training data returned
        by ``get_data``. Then asserts that the ``binary_accuracy`` history is
        exactly ``[1.0, 1.0]``.

        Args:
            distribution: strategy object forwarded to ``get_model`` and to
                ``keras_correctness_test_base.get_batch_size``.
        """
        with self.cached_session():
            self.set_up_test_config()

            # Only the training split is used; the third item is ignored.
            x_train, y_train, _ = self.get_data()
            model = self.get_model(distribution=distribution)

            # Nominal batch size is 64; the helper may adjust it for the
            # strategy (its exact rule is defined outside this file chunk).
            effective_batch = keras_correctness_test_base.get_batch_size(
                64, distribution)
            train_ds = keras_correctness_test_base.batch_wrapper(
                tf.data.Dataset.from_tensor_slices((x_train, y_train)),
                effective_batch)

            history = model.fit(x=train_ds, epochs=2, steps_per_epoch=10)
            accuracy_per_epoch = history.history['binary_accuracy']
            self.assertEqual(accuracy_per_epoch, [1.0, 1.0])