Пример #1
0
    def testSentimentExampleAccuracy(self):
        raw_data_dir = os.path.join(os.path.dirname(__file__),
                                    'testdata/sentiment')
        working_dir = self.get_temp_dir()

        # Copy data from raw data directory to `working_dir`
        for filename in [
                'test_shuffled-00000-of-00001', 'train_shuffled-00000-of-00001'
        ]:
            shutil.copy(os.path.join(raw_data_dir, filename), working_dir)

        sentiment_example.transform_data(working_dir)
        results = sentiment_example.train_and_evaluate(
            working_dir, num_train_instances=1000, num_test_instances=1000)
        self.assertGreaterEqual(results['accuracy'], 0.7)

        # Delete temp directory and transform_fn directory.  This ensures that the
        # test of serving the model below will only pass if the SavedModel saved
        # to sentiment_example.EXPORTED_MODEL_DIR is hermetic, i.e does not contain
        # references to tft_temp and transform_fn.
        shutil.rmtree(
            os.path.join(working_dir, sentiment_example.TRANSFORM_TEMP_DIR))
        shutil.rmtree(
            os.path.join(working_dir, tft.TFTransformOutput.TRANSFORM_FN_DIR))

        if local_model_server.local_model_server_supported():
            model_name = 'my_model'
            model_path = os.path.join(working_dir,
                                      sentiment_example.EXPORTED_MODEL_DIR)
            with local_model_server.start_server(model_name,
                                                 model_path) as address:
                # Use made up data chosen to give high probability of negative
                # sentiment.
                ascii_classification_request = """model_spec { name: "my_model" }
input {
  example_list {
    examples {
      features {
        feature {
          key: "review"
          value: {
            bytes_list {
              value: "errible terrible terrible terrible terrible terrible terrible."
            }
          }
        }
      }
    }
  }
}"""
                results = local_model_server.make_classification_request(
                    address, ascii_classification_request)
                self.assertEqual(len(results), 1)
                self.assertEqual(len(results[0].classes), 2)
                self.assertEqual(results[0].classes[0].label, '0')
                self.assertGreater(results[0].classes[0].score, 0.8)
                self.assertEqual(results[0].classes[1].label, '1')
                self.assertLess(results[0].classes[1].score, 0.2)
Пример #2
0
    def testSentimentExampleAccuracy(self):
        raw_data_dir = os.path.join(os.path.dirname(__file__),
                                    'testdata/sentiment')
        working_dir = self.get_temp_dir()

        # Copy data from raw data directory to `working_dir`
        for filename in [
                'test_shuffled-00000-of-00001', 'train_shuffled-00000-of-00001'
        ]:
            shutil.copy(os.path.join(raw_data_dir, filename), working_dir)

        sentiment_example.transform_data(working_dir)
        results = sentiment_example.train_and_evaluate(
            working_dir, num_train_instances=1000, num_test_instances=1000)
        self.assertGreaterEqual(results['accuracy'], 0.7)

        if local_model_server.local_model_server_supported():
            model_name = 'my_model'
            model_path = os.path.join(working_dir,
                                      sentiment_example.EXPORTED_MODEL_DIR)
            with local_model_server.start_server(model_name,
                                                 model_path) as address:
                # Use made up data chosen to give high probability of negative
                # sentiment.
                ascii_classification_request = """model_spec { name: "my_model" }
input {
  example_list {
    examples {
      features {
        feature {
          key: "review"
          value: {
            bytes_list {
              value: "errible terrible terrible terrible terrible terrible terrible."
            }
          }
        }
      }
    }
  }
}"""
                results = local_model_server.make_classification_request(
                    address, ascii_classification_request)
                self.assertEqual(len(results), 1)
                self.assertEqual(len(results[0].classes), 2)
                self.assertEqual(results[0].classes[0].label, '0')
                self.assertGreater(results[0].classes[0].score, 0.8)
                self.assertEqual(results[0].classes[1].label, '1')
                self.assertLess(results[0].classes[1].score, 0.2)