Exemplo n.º 1
0
  def test_max_results_option(self):
    """Verify that max_results caps how many categories are returned."""
    limit = 3
    classifier = AudioClassifier(
        _MODEL_FILE, options=AudioClassifierOptions(max_results=limit))
    result_count = len(classifier.classify(self._input_tensor))

    self.assertLessEqual(result_count, limit, 'Too many results returned.')
Exemplo n.º 2
0
  def test_deny_list(self):
    """Verify that no returned label appears in label_deny_list."""
    denied_labels = ['Animal']
    classifier = AudioClassifier(
        _MODEL_FILE,
        options=AudioClassifierOptions(label_deny_list=denied_labels))
    results = classifier.classify(self._input_tensor)

    # Walk the returned labels one by one so a failure names the offender.
    for label in (result.label for result in results):
      self.assertNotIn(label, denied_labels,
                       'Label "{0}" found but in deny list.'.format(label))
Exemplo n.º 3
0
  def test_allow_list(self):
    """Verify that every returned label is contained in label_allow_list."""
    permitted_labels = ['Cat']
    classifier = AudioClassifier(
        _MODEL_FILE,
        options=AudioClassifierOptions(label_allow_list=permitted_labels))
    results = classifier.classify(self._input_tensor)

    # Walk the returned labels one by one so a failure names the offender.
    for label in (result.label for result in results):
      self.assertIn(
          label, permitted_labels,
          'Label "{0}" found but not in label allow list'.format(label))
Exemplo n.º 4
0
  def test_score_threshold_option(self):
    """Verify that no returned category scores below score_threshold."""
    threshold = 0.5
    classifier = AudioClassifier(
        _MODEL_FILE,
        options=AudioClassifierOptions(score_threshold=threshold))
    results = classifier.classify(self._input_tensor)

    for result in results:
      failure_message = (
          'Classification with score lower than threshold found. {0}'.format(
              result))
      self.assertGreaterEqual(result.score, threshold, failure_message)
Exemplo n.º 5
0
  def test_default_option(self):
    """Verify that default options reproduce every ground-truth category.

    A ground-truth entry is considered reproduced when at least one returned
    category has the same label and a score that differs from the expected
    score by less than _ACCEPTABLE_ERROR_RANGE.
    """

    def _matches(actual, expected):
      # Evaluate both criteria before combining them, mirroring a check that
      # computes the label and score comparisons independently.
      same_label = actual.label == expected.label
      close_score = abs(actual.score -
                        expected.score) < _ACCEPTABLE_ERROR_RANGE
      return same_label and close_score

    classifier = AudioClassifier(_MODEL_FILE)
    results = classifier.classify(self._input_tensor)

    for expected in self._ground_truth_classifications:
      # any() stops at the first matching category, like an early break.
      found = any(_matches(actual, expected) for actual in results)
      self.assertTrue(found, '{0} not found.'.format(expected))
Exemplo n.º 6
0
  def _create_ground_truth_csv(self, output_file=_GROUND_TRUTH_FILE):
    """A util function to regenerate the ground truth result.

    This function is not used in the test but it exists to make adding more
    audio and ground truth data to the test easier in the future.

    Runs the classifier with default options on the test input and writes one
    CSV row per returned category, under a `label,score` header.

    Args:
      output_file: Filename to write the ground truth CSV.
    """
    classifier = AudioClassifier(_MODEL_FILE)
    categories = classifier.classify(self._input_tensor)
    # The csv module must control line endings itself: without newline='',
    # its '\r\n' row terminator is mistranslated on platforms that convert
    # newlines (e.g. written as '\r\r\n' on Windows), adding blank rows.
    with open(output_file, 'w', newline='') as f:
      writer = csv.DictWriter(f, fieldnames=['label', 'score'])
      writer.writeheader()
      writer.writerows({
          'label': category.label,
          'score': category.score,
      } for category in categories)