コード例 #1
0
 def test_network_infer_lib(self):
     """Serving with only_network=True returns raw class/box outputs.

     With post-processing disabled, `serve` hands back a pair of
     (class outputs, box outputs); each is expected to hold five entries
     (presumably one per feature level -- confirm against the model).
     """
     network_driver = infer_lib.ServingDriver(
         'efficientdet-d0', self.tmp_path, only_network=True)
     image_batch = tf.ones((1, 512, 512, 3))
     cls_out, box_out = network_driver.serve(image_batch)
     for head_outputs in (cls_out, box_out):
         self.assertLen(head_outputs, 5)
コード例 #2
0
 def test_export(self):
     """Export efficientdet-d0 and reload the exported artifacts.

     Exports into `<tmp_path>/saved_model` and checks that a SavedModel is
     present there. It then loads both the SavedModel directory and the
     `efficientdet-d0_frozen.pb` file expected inside that directory.
     """
     saved_model_path = os.path.join(self.tmp_path, 'saved_model')
     driver = infer_lib.ServingDriver('efficientdet-d0', self.tmp_path)
     driver.export(saved_model_path)
     # contains_saved_model returns a plain bool; assertTrue states the
     # intent directly instead of routing it through an array comparison.
     self.assertTrue(tf.saved_model.contains_saved_model(saved_model_path))
     driver.load(saved_model_path)
     driver.load(
         os.path.join(saved_model_path, 'efficientdet-d0_frozen.pb'))
コード例 #3
0
 def test_infer_lib(self):
     """Serve efficientdet-d0 end to end on one constant 512x512 image.

     Checks mean values and shapes of the four post-processed outputs
     (boxes, scores, classes, valid_lens) against reference numbers.
     """
     driver = infer_lib.ServingDriver('efficientdet-d0', self.tmp_path)
     images = tf.ones((1, 512, 512, 3))
     boxes, scores, classes, valid_lens = driver.serve(images)
     # Floating-point means are compared with a tolerance. The references
     # are rounded literals (e.g. 163.09), and exact equality on a reduced
     # float value breaks on tiny changes in kernels or reduction order.
     self.assertAllClose(tf.reduce_mean(boxes), 163.09, rtol=1e-5)
     self.assertAllClose(tf.reduce_mean(scores), 0.01000005, rtol=1e-5)
     self.assertEqual(tf.reduce_mean(classes), 1)
     self.assertEqual(tf.reduce_mean(valid_lens), 100)
     self.assertEqual(boxes.shape, (1, 100, 4))
     self.assertEqual(scores.shape, (1, 100))
     self.assertEqual(classes.shape, (1, 100))
     self.assertEqual(valid_lens.shape, (1, ))
コード例 #4
0
 def test_export_tflite_with_post_processing(self):
     """Export efficientdet-lite0 (with post-processing) to FP32 and INT8 TFLite.

     First exports FP32 and checks that `fp32.tflite` is written. It then
     clears the export directory and re-exports as INT8, calibrating for
     one step on a fake TFRecord, and checks that `int8.tflite` is written.
     """
     export_dir = os.path.join(self.lite_tmp_path, 'saved_model')
     lite_driver = infer_lib.ServingDriver(
         'efficientdet-lite0', self.lite_tmp_path, only_network=False)

     # FP32 export.
     lite_driver.export(export_dir, tflite='FP32')
     fp32_file = os.path.join(export_dir, 'fp32.tflite')
     self.assertTrue(tf.io.gfile.exists(fp32_file))

     # Start from an empty directory before the INT8 export.
     tf.io.gfile.rmtree(export_dir)

     # INT8 export, calibrated on a fake TFRecord.
     fake_records = test_util.make_fake_tfrecord(self.get_temp_dir())
     lite_driver.export(
         export_dir,
         tflite='INT8',
         file_pattern=[fake_records],
         num_calibration_steps=1)
     int8_file = os.path.join(export_dir, 'int8.tflite')
     self.assertTrue(tf.io.gfile.exists(int8_file))
コード例 #5
0
 def test_infer_lib_mixed_precision(self):
     """Serve efficientdet-d0 after building with mixed_precision enabled.

     Output shapes are always checked. Value statistics are only compared
     when the global Keras policy reads 'float32'. Under any other policy,
     the reduced-precision outputs are not compared against the float32
     reference numbers.
     """
     driver = infer_lib.ServingDriver('efficientdet-d0', self.tmp_path)
     driver.build({'mixed_precision': True})
     images = tf.ones((1, 512, 512, 3))
     boxes, scores, classes, valid_lens = driver.serve(images)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == 'float32':
         # Tolerance-based comparison: the references are rounded literals.
         self.assertAllClose(tf.reduce_mean(boxes), 163.09, rtol=1e-5)
         self.assertAllClose(tf.reduce_mean(scores), 0.01000005, rtol=1e-5)
         self.assertEqual(tf.reduce_mean(classes), 1)
         self.assertEqual(tf.reduce_mean(valid_lens), 100)
     self.assertEqual(boxes.shape, (1, 100, 4))
     self.assertEqual(scores.shape, (1, 100))
     self.assertEqual(classes.shape, (1, 100))
     self.assertEqual(valid_lens.shape, (1, ))
コード例 #6
0
ファイル: inspector.py プロジェクト: wanghao-coins/automl
def _run_export_mode(driver):
  """Export the served model to --saved_model_dir, replacing any old copy."""
  if not FLAGS.saved_model_dir:
    raise ValueError('Please specify --saved_model_dir=')
  model_dir = FLAGS.saved_model_dir
  if tf.io.gfile.exists(model_dir):
    # Remove a previous export so stale files do not mix with the new one.
    tf.io.gfile.rmtree(model_dir)
  driver.export(model_dir, FLAGS.tensorrt, FLAGS.tflite, FLAGS.file_pattern,
                FLAGS.num_calibration_steps)
  print('Model are exported to %s' % model_dir)


def _run_infer_mode(driver, model_config):
  """Detect objects in --input_image and save the visualization as 0.jpg."""
  image_file = tf.io.read_file(FLAGS.input_image)
  # channels=3 forces 3-channel decoding, so grayscale or RGBA inputs match
  # the (None, None, 3) static shape below instead of failing set_shape.
  image_arrays = tf.io.decode_image(image_file, channels=3)
  image_arrays.set_shape((None, None, 3))
  image_arrays = tf.expand_dims(image_arrays, axis=0)
  if FLAGS.saved_model_dir:
    driver.load(FLAGS.saved_model_dir)
    if FLAGS.saved_model_dir.endswith('.tflite'):
      # Pad-resize to the configured size and cast to uint8 for the TFLite
      # model (presumably it takes a fixed-size uint8 input -- confirm).
      image_size = utils.parse_image_size(model_config.image_size)
      image_arrays = tf.image.resize_with_pad(image_arrays, *image_size)
      image_arrays = tf.cast(image_arrays, tf.uint8)
  detections_bs = driver.serve(image_arrays)
  boxes, scores, classes, _ = tf.nest.map_structure(np.array, detections_bs)
  raw_image = Image.fromarray(np.array(image_arrays)[0])
  img = driver.visualize(
      raw_image,
      boxes[0],
      classes[0],
      scores[0],
      # Fall back to 0.4 when the config's score threshold is unset/falsy.
      min_score_thresh=model_config.nms_configs.score_thresh or 0.4,
      max_boxes_to_draw=model_config.nms_configs.max_output_size)
  output_image_path = os.path.join(FLAGS.output_image_dir, '0.jpg')
  Image.fromarray(img).save(output_image_path)
  print('writing file to %s' % output_image_path)


def _run_benchmark_mode(driver, model_config):
  """Benchmark serving on --input_image, or on synthetic ones if unset."""
  if FLAGS.saved_model_dir:
    driver.load(FLAGS.saved_model_dir)

  batch_size = FLAGS.batch_size or 1
  if FLAGS.input_image:
    image_file = tf.io.read_file(FLAGS.input_image)
    # channels=3 keeps grayscale/RGBA inputs compatible with set_shape below.
    image_arrays = tf.image.decode_image(image_file, channels=3)
    image_arrays.set_shape((None, None, 3))
    image_arrays = tf.expand_dims(image_arrays, 0)
    if batch_size > 1:
      # Replicate the single image along the batch axis.
      image_arrays = tf.tile(image_arrays, [batch_size, 1, 1, 1])
  else:
    # use synthetic data if no image is provided.
    image_arrays = tf.ones((batch_size, *model_config.image_size, 3),
                           dtype=tf.uint8)
  if FLAGS.only_network:
    # Convert uint8 -> float32 in [0, 1] and resize to the model input size
    # (presumably because nothing pre-processes input for the bare network
    # -- confirm).
    image_arrays = tf.image.convert_image_dtype(image_arrays, tf.float32)
    image_arrays = tf.image.resize(image_arrays, model_config.image_size)
  driver.benchmark(image_arrays, FLAGS.bm_runs, FLAGS.trace_filename)


def _run_dry_mode(driver):
  """Build the model and optionally save its weights to --export_ckpt."""
  # transfer to tf2 format ckpt
  driver.build()
  if FLAGS.export_ckpt:
    driver.model.save_weights(FLAGS.export_ckpt)


def _run_video_mode(driver, model_config):
  """Run detection frame by frame on --input_video.

  Annotated frames are written to --output_video when it is set. Otherwise
  each frame is shown in a window, and pressing 'q' stops early.

  Raises:
    ValueError: if --input_video cannot be opened.
  """
  import cv2  # pylint: disable=g-import-not-at-top
  if FLAGS.saved_model_dir:
    driver.load(FLAGS.saved_model_dir)
  cap = cv2.VideoCapture(FLAGS.input_video)
  if not cap.isOpened():
    # Fail loudly. Continuing would process no frames and, with
    # --output_video set, still create a 0x0 writer/output file.
    cap.release()
    raise ValueError('Error opening input video: {}'.format(FLAGS.input_video))

  out_ptr = None
  window_shown = False
  try:
    if FLAGS.output_video:
      frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
      frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
      out_ptr = cv2.VideoWriter(FLAGS.output_video,
                                cv2.VideoWriter_fourcc('m', 'p', '4', 'v'),
                                cap.get(cv2.CAP_PROP_FPS),
                                (frame_width, frame_height))

    while cap.isOpened():
      # Capture frame-by-frame
      ret, frame = cap.read()
      if not ret:
        break

      # NOTE(review): OpenCV returns frames in BGR order, whereas 'infer'
      # mode decodes RGB via tf.io.decode_image; confirm which order the
      # model expects here.
      raw_frames = np.array([frame])
      detections_bs = driver.serve(raw_frames)
      boxes, scores, classes, _ = tf.nest.map_structure(np.array, detections_bs)
      new_frame = driver.visualize(
          raw_frames[0],
          boxes[0],
          classes[0],
          scores[0],
          min_score_thresh=model_config.nms_configs.score_thresh or 0.4,
          max_boxes_to_draw=model_config.nms_configs.max_output_size)

      if out_ptr is not None:
        # write frame into output file.
        out_ptr.write(new_frame)
      else:
        # show the frame online, mainly used for real-time speed test.
        cv2.imshow('Frame', new_frame)
        window_shown = True
        # Press Q on keyboard to exit
        if cv2.waitKey(1) & 0xFF == ord('q'):
          break
  finally:
    # Release handles explicitly so the output file is closed even if
    # serving raises part-way through the video.
    cap.release()
    if out_ptr is not None:
      out_ptr.release()
    if window_shown:
      cv2.destroyAllWindows()


def main(_):
  """Build a ServingDriver from flags and run the selected --mode.

  Supported modes: 'export', 'infer', 'benchmark', 'dry' and 'video'.

  Raises:
    ValueError: if --mode is unknown, a mode's required flag is missing, or
      the input video cannot be opened.
  """
  tf.config.run_functions_eagerly(FLAGS.debug)
  # Let TF grow GPU memory on demand instead of reserving it all up front.
  for device in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)

  model_config = hparams_config.get_detection_config(FLAGS.model_name)
  model_config.override(FLAGS.hparams)  # Add custom overrides
  model_config.is_training_bn = False
  if FLAGS.image_size != -1:
    model_config.image_size = FLAGS.image_size
  model_config.image_size = utils.parse_image_size(model_config.image_size)

  model_params = model_config.as_dict()
  ckpt_path_or_file = FLAGS.model_dir
  if tf.io.gfile.isdir(ckpt_path_or_file):
    # A directory is resolved to the latest checkpoint inside it.
    ckpt_path_or_file = tf.train.latest_checkpoint(ckpt_path_or_file)
  # A zero/unset --batch_size is passed on as None.
  driver = infer_lib.ServingDriver(FLAGS.model_name, ckpt_path_or_file,
                                   FLAGS.batch_size or None,
                                   FLAGS.only_network, model_params)
  if FLAGS.mode == 'export':
    _run_export_mode(driver)
  elif FLAGS.mode == 'infer':
    _run_infer_mode(driver, model_config)
  elif FLAGS.mode == 'benchmark':
    _run_benchmark_mode(driver, model_config)
  elif FLAGS.mode == 'dry':
    _run_dry_mode(driver)
  elif FLAGS.mode == 'video':
    _run_video_mode(driver, model_config)
  else:
    raise ValueError('Unknown --mode: %s' % FLAGS.mode)