コード例 #1
0
  def test_max_results_option(self):
    """Checks that the classifier honours the max_results cap.

    The classifier is built with max_results=_MAX_RESULTS, and the number of
    categories it returns for the test image must not exceed that cap.
    """
    capped_options = ImageClassifierOptions(max_results=_MAX_RESULTS)
    capped_classifier = ImageClassifier(_MODEL_FILE, options=capped_options)
    result_count = len(capped_classifier.classify(self.image))
    self.assertLessEqual(result_count, _MAX_RESULTS,
                         'Too many results returned.')
コード例 #2
0
  def test_deny_list(self):
    """Checks that labels on the deny list are never returned.

    Every category produced for the test image by a classifier configured
    with label_deny_list=_DENY_LIST must carry a label outside that list.
    """
    deny_options = ImageClassifierOptions(label_deny_list=_DENY_LIST)
    deny_classifier = ImageClassifier(_MODEL_FILE, options=deny_options)
    returned_labels = (
        result.label for result in deny_classifier.classify(self.image))

    for returned_label in returned_labels:
      failure_message = 'Label "{0}" found but in deny list.'.format(
          returned_label)
      self.assertNotIn(returned_label, _DENY_LIST, failure_message)
コード例 #3
0
  def test_allow_list(self):
    """Checks that only labels on the allow list are returned.

    Every category produced for the test image by a classifier configured
    with label_allow_list=_ALLOW_LIST must carry a label from that list.
    """
    allow_options = ImageClassifierOptions(label_allow_list=_ALLOW_LIST)
    allow_classifier = ImageClassifier(_MODEL_FILE, options=allow_options)
    returned_labels = (
        result.label for result in allow_classifier.classify(self.image))

    for returned_label in returned_labels:
      failure_message = (
          'Label "{0}" found but not in label allow list'.format(
              returned_label))
      self.assertIn(returned_label, _ALLOW_LIST, failure_message)
コード例 #4
0
  def test_score_threshold_option(self):
    """Checks that no returned category scores below the threshold.

    The classifier is built with score_threshold=_SCORE_THRESHOLD, and every
    category it returns for the test image must have score >= that value.
    """
    threshold_options = ImageClassifierOptions(score_threshold=_SCORE_THRESHOLD)
    threshold_classifier = ImageClassifier(_MODEL_FILE,
                                           options=threshold_options)

    for result in threshold_classifier.classify(self.image):
      failure_message = (
          'Classification with score lower than threshold found. {0}'
          .format(result))
      self.assertGreaterEqual(result.score, _SCORE_THRESHOLD, failure_message)
コード例 #5
0
  def test_default_option(self):
    """Checks the classifier against ground truth using default options.

    For each expected classification in self._ground_truth_classifications,
    at least one returned category must have the same label and a score
    within _ACCEPTABLE_ERROR_RANGE (strictly less) of the expected score.
    """
    default_classifier = ImageClassifier(_MODEL_FILE)
    predictions = default_classifier.classify(self.image)

    for expected in self._ground_truth_classifications:
      # A ground-truth entry is satisfied by any prediction that agrees on
      # the label and is numerically close enough on the score.
      matched = any(
          predicted.label == expected.label and
          abs(predicted.score - expected.score) < _ACCEPTABLE_ERROR_RANGE
          for predicted in predictions)
      self.assertTrue(matched, '{0} not found.'.format(expected))
コード例 #6
0
  def _create_ground_truth_csv(self, output_file=_GROUND_TRUTH_FILE):
    """A util function to regenerate the ground truth result.

    This function is not used in the test but it exists to make adding more
    images and ground truth data to the test easier in the future.

    Runs the classifier with default options on self.image and writes one CSV
    row per returned category, with a 'label,score' header row.

    Args:
      output_file: Filename to write the ground truth CSV. Any existing file
        is overwritten.
    """
    classifier = ImageClassifier(_MODEL_FILE)
    categories = classifier.classify(self.image)
    # newline='' is required by the csv module: the writer emits its own
    # '\r\n' row terminators, and text-mode newline translation would
    # otherwise turn them into '\r\r\n' (blank rows) on Windows.
    with open(output_file, 'w', newline='') as f:
      writer = csv.DictWriter(f, fieldnames=['label', 'score'])
      writer.writeheader()
      writer.writerows({
          'label': category.label,
          'score': category.score,
      } for category in categories)
コード例 #7
0
def run(model: str, max_results: int, num_threads: int, enable_edgetpu: bool,
        camera_id: int, width: int, height: int) -> None:
    """Continuously run inference on images acquired from the camera.

    Each captured frame is mirrored (cv2.flip with flipCode=1) and classified.
    It is then displayed in an OpenCV window with one line per returned
    category plus an FPS estimate. The FPS estimate is refreshed every
    _FPS_AVERAGE_FRAME_COUNT frames.

    The loop stops when the ESC key is pressed or the capture device reports
    it is no longer open. The camera and all OpenCV windows are released on
    every exit path.

    Args:
        model: Name of the TFLite image classification model.
        max_results: Max of classification results.
        num_threads: Number of CPU threads to run the model.
        enable_edgetpu: Whether to run the model on EdgeTPU.
        camera_id: The camera id to be passed to OpenCV.
        width: The width of the frame captured from the camera.
        height: The height of the frame captured from the camera.

    Raises:
        SystemExit: If a frame cannot be read from the camera.
    """
    # Initialize the image classification model
    options = ImageClassifierOptions(num_threads=num_threads,
                                     max_results=max_results,
                                     enable_edgetpu=enable_edgetpu)
    classifier = ImageClassifier(model, options)

    # Variables to calculate FPS
    counter, fps = 0, 0
    start_time = time.time()

    # Start capturing video input from the camera
    cap = cv2.VideoCapture(camera_id)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

    try:
        # Continuously capture images from the camera and run inference
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                sys.exit(
                    'ERROR: Unable to read from webcam. Please verify your webcam settings.'
                )

            counter += 1
            image = cv2.flip(image, 1)

            # Draw one classification result per row. Row 1 is reserved for
            # the FPS text, so results start at row 2.
            categories = classifier.classify(image)
            for idx, category in enumerate(categories):
                score = round(category.score, 2)
                result_text = f'{category.label} ({score})'
                text_location = (_LEFT_MARGIN, (idx + 2) * _ROW_SIZE)
                cv2.putText(image, result_text, text_location,
                            cv2.FONT_HERSHEY_PLAIN, _FONT_SIZE, _TEXT_COLOR,
                            _FONT_THICKNESS)

            # Average the FPS over a window of frames. The next window starts
            # at end_time so no elapsed time is dropped between windows.
            if counter % _FPS_AVERAGE_FRAME_COUNT == 0:
                end_time = time.time()
                fps = _FPS_AVERAGE_FRAME_COUNT / (end_time - start_time)
                start_time = end_time

            # Show the FPS
            fps_text = 'FPS = ' + str(int(fps))
            text_location = (_LEFT_MARGIN, _ROW_SIZE)
            cv2.putText(image, fps_text, text_location, cv2.FONT_HERSHEY_PLAIN,
                        _FONT_SIZE, _TEXT_COLOR, _FONT_THICKNESS)

            # Show the frame before polling: cv2.waitKey is what lets HighGUI
            # render the pending imshow, so this displays the current frame
            # in this iteration rather than on the next one.
            cv2.imshow('image_classification', image)

            # Stop the program if the ESC key is pressed.
            if cv2.waitKey(1) == 27:
                break
    finally:
        # Release resources even when sys.exit() raises SystemExit or
        # classification fails mid-loop.
        cap.release()
        cv2.destroyAllWindows()
コード例 #8
0
    # Perform classification (test mode).
    # Resolves the checkpoint directory and picks a checkpoint matching the
    # requested alias (and epoch, if given). It then runs IMG_MODEL.classify
    # on TF_REC with that checkpoint.
    elif const.TST_MODE in ARGS.mode:

        # Default to <MODELS_FOLDER>/<arch> when no directory is given. An
        # explicitly given directory must already exist.
        # NOTE(review): the default directory is not checked here. If it is
        # missing, glob() below returns [] and the user sees "Checkpoint file
        # does not exist!" instead.
        if ARGS.model_dir is None:
            ARGS.model_dir = os.path.join(const.MODELS_FOLDER, ARGS.arch)

        elif not os.path.isdir(ARGS.model_dir):
            # Presumably argparse's ArgumentParser.error, which prints the
            # message and exits. TODO confirm ARG_PARSER's type.
            ARG_PARSER.error("Model checkpoint directory does not exist!")

        # wildcards should always follow arrangement in ut.save_model
        # filename format arrangement: date, arch, alias, epoch
        # NOTE(review): the epoch-specific pattern below places the epoch at
        # the START of the filename ("<epoch>*_<alias>*"). That does not match
        # the "date, arch, alias, epoch" order stated above. Verify against
        # ut.save_model.
        if ARGS.model_epoch is None:
            # get all checkpoint files
            MDL_WILDCARD = os.path.join(ARGS.model_dir,
                                        '*_' + ARGS.alias + '*.mdl.data*')

        else:
            MDL_WILDCARD = os.path.join(
                ARGS.model_dir,
                str(ARGS.model_epoch) + '*_' + ARGS.alias + '*.mdl.data*')

        CHKPT_PATH = glob.glob(MDL_WILDCARD)
        if not CHKPT_PATH:
            ARG_PARSER.error("Checkpoint file does not exist!")

        # purpose of reverse is to get MAX step in case ARGS.model_epoch is not specified
        # sort has no effect if ARGS.model_epoch is specified
        # NOTE(review): this is a lexicographic string sort on full paths. It
        # only yields the highest step if the varying part of the filename is
        # fixed-width, e.g. zero-padded or a sortable date ('9' sorts after
        # '10'). Also, "no effect" holds only if the epoch pattern matches a
        # single file.
        CHKPT_PATH.sort(reverse=True)
        # Strip everything from the first '.data' onward, leaving (presumably)
        # the checkpoint prefix that IMG_MODEL.classify expects.
        # NOTE(review): if model_dir itself contains '.data', this truncates
        # the directory path instead.
        IMG_MODEL.classify(TF_REC, CHKPT_PATH[0].split('.data')[0], ARGS.alias,
                           NUM_DATA)