Example #1
  def _generate_metadata(self,
                         key,
                         download_output,
                         video_path_format_string=None):
    """For each row in the annotation CSV, generates the corresponding metadata.

    Args:
      key: which split to process.
      download_output: the tuple output of _download_data containing
        - annotations_files: dict of keys to CSV annotation paths.
        - label_map: dict mapping from label strings to numeric indices.
      video_path_format_string: The format string for the path to local files.

    Yields:
      Each tf.SequenceExample of metadata, ready to pass to MediaPipe.
    """
    annotations_files, label_map = download_output
    with open(annotations_files[key], "r") as annotations:
      reader = csv.reader(annotations)
      for i, csv_row in enumerate(reader):
        if i == 0:  # the first row is the header
          continue
        # rename the row with a consistent set of names.
        if len(csv_row) == 5:
          row = dict(
              zip(["label_name", "video", "start", "end", "split"], csv_row))
        else:
          row = dict(zip(["video", "start", "end", "split"], csv_row))
        metadata = tf.train.SequenceExample()
        ms.set_example_id(bytes23(row["video"] + "_" + row["start"]),
                          metadata)
        ms.set_clip_media_id(bytes23(row["video"]), metadata)
        ms.set_clip_alternative_media_id(bytes23(row["split"]), metadata)
        if video_path_format_string:
          filepath = video_path_format_string.format(**row)
          ms.set_clip_data_path(bytes23(filepath), metadata)
        assert row["start"].isdigit(), "Invalid row: %s" % str(row)
        assert row["end"].isdigit(), "Invalid row: %s" % str(row)
        if "label_name" in row:
          ms.set_clip_label_string([bytes23(row["label_name"])], metadata)
          if label_map:
            ms.set_clip_label_index([label_map[row["label_name"]]], metadata)
        yield metadata
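
A note on the bytes23 helper used throughout these examples: it is not defined here, but it is a small Python 2/3 compatibility shim for encoding strings as bytes. A minimal sketch consistent with how it is called above:

import sys

def bytes23(string):
  """Creates a bytes string in either Python 2 or 3."""
  if sys.version_info >= (3, 0):
    return bytes(string, "utf-8")
  return bytes(string)

And a hypothetical driver for _generate_metadata; the dataset instance, the annotation path, the labels, and the format string below are made up for illustration:

download_output = (
    {"train": "/tmp/annotations_train.csv"},  # hypothetical annotations_files
    {"jumping jacks": 0, "pull ups": 1},      # hypothetical label_map
)
for metadata in dataset._generate_metadata(
    "train", download_output,
    video_path_format_string="/videos/{video}.mp4"):
  print(metadata)  # a tf.train.SequenceExample with its context features set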

Example #2

  def test_expected_functions_are_defined(self):
    # The code from media_sequence_util is already tested, but this test
    # ensures that we actually generate the expected methods. We only test one
    # per feature, and the only test is to not crash with undefined
    # attributes. By passing in a value, we also ensure that the types are
    # correct because the underlying code crashes with a type mismatch.
    example = tf.train.SequenceExample()
    # context
    ms.set_example_id(b"string", example)
    ms.set_example_dataset_name(b"string", example)
    ms.set_clip_media_id(b"string", example)
    ms.set_clip_alternative_media_id(b"string", example)
    ms.set_clip_encoded_media_bytes(b"string", example)
    ms.set_clip_encoded_media_start_timestamp(47, example)
    ms.set_clip_data_path(b"string", example)
    ms.set_clip_start_timestamp(47, example)
    ms.set_clip_end_timestamp(47, example)
    ms.set_clip_label_string((b"string", b"test"), example)
    ms.set_clip_label_index((47, 49), example)
    ms.set_clip_label_confidence((0.47, 0.49), example)
    ms.set_segment_start_timestamp((47, 49), example)
    ms.set_segment_start_index((47, 49), example)
    ms.set_segment_end_timestamp((47, 49), example)
    ms.set_segment_end_index((47, 49), example)
    ms.set_segment_label_index((47, 49), example)
    ms.set_segment_label_string((b"test", b"strings"), example)
    ms.set_segment_label_confidence((0.47, 0.49), example)
    ms.set_image_format(b"test", example)
    ms.set_image_channels(47, example)
    ms.set_image_colorspace(b"test", example)
    ms.set_image_height(47, example)
    ms.set_image_width(47, example)
    ms.set_image_frame_rate(0.47, example)
    ms.set_image_data_path(b"test", example)
    ms.set_forward_flow_format(b"test", example)
    ms.set_forward_flow_channels(47, example)
    ms.set_forward_flow_colorspace(b"test", example)
    ms.set_forward_flow_height(47, example)
    ms.set_forward_flow_width(47, example)
    ms.set_forward_flow_frame_rate(0.47, example)
    ms.set_class_segmentation_format(b"test", example)
    ms.set_class_segmentation_height(47, example)
    ms.set_class_segmentation_width(47, example)
    ms.set_class_segmentation_class_label_string((b"test", b"strings"),
                                                 example)
    ms.set_class_segmentation_class_label_index((47, 49), example)
    ms.set_instance_segmentation_format(b"test", example)
    ms.set_instance_segmentation_height(47, example)
    ms.set_instance_segmentation_width(47, example)
    ms.set_instance_segmentation_object_class_index((47, 49), example)
    ms.set_bbox_parts((b"HEAD", b"TOE"), example)
    # feature lists
    ms.add_image_encoded(b"test", example)
    ms.add_image_multi_encoded([b"test", b"test"], example)
    ms.add_image_timestamp(47, example)
    ms.add_forward_flow_encoded(b"test", example)
    ms.add_forward_flow_multi_encoded([b"test", b"test"], example)
    ms.add_forward_flow_timestamp(47, example)
    ms.add_bbox_ymin((0.47, 0.49), example)
    ms.add_bbox_xmin((0.47, 0.49), example)
    ms.add_bbox_ymax((0.47, 0.49), example)
    ms.add_bbox_xmax((0.47, 0.49), example)
    ms.add_bbox_point_x((0.47, 0.49), example)
    ms.add_bbox_point_y((0.47, 0.49), example)
    ms.add_predicted_bbox_ymin((0.47, 0.49), example)
    ms.add_predicted_bbox_xmin((0.47, 0.49), example)
    ms.add_predicted_bbox_ymax((0.47, 0.49), example)
    ms.add_predicted_bbox_xmax((0.47, 0.49), example)
    ms.add_bbox_num_regions(47, example)
    ms.add_bbox_is_annotated(47, example)
    ms.add_bbox_is_generated((47, 49), example)
    ms.add_bbox_is_occluded((47, 49), example)
    ms.add_bbox_label_index((47, 49), example)
    ms.add_bbox_label_string((b"test", b"strings"), example)
    ms.add_bbox_label_confidence((0.47, 0.49), example)
    ms.add_bbox_class_index((47, 49), example)
    ms.add_bbox_class_string((b"test", b"strings"), example)
    ms.add_bbox_class_confidence((0.47, 0.49), example)
    ms.add_bbox_track_index((47, 49), example)
    ms.add_bbox_track_string((b"test", b"strings"), example)
    ms.add_bbox_track_confidence((0.47, 0.49), example)
    ms.add_bbox_timestamp(47, example)
    ms.add_predicted_bbox_class_index((47, 49), example)
    ms.add_predicted_bbox_class_string((b"test", b"strings"), example)
    ms.add_predicted_bbox_timestamp(47, example)
    ms.add_class_segmentation_encoded(b"test", example)
    ms.add_class_segmentation_multi_encoded([b"test", b"test"], example)
    ms.add_instance_segmentation_encoded(b"test", example)
    ms.add_instance_segmentation_multi_encoded([b"test", b"test"], example)
    ms.add_class_segmentation_timestamp(47, example)
    ms.set_bbox_embedding_dimensions_per_region((47, 49), example)
    ms.set_bbox_embedding_format(b"test", example)
    ms.add_bbox_embedding_floats((0.47, 0.49), example)
    ms.add_bbox_embedding_encoded((b"test", b"strings"), example)
    ms.add_bbox_embedding_confidence((0.47, 0.49), example)
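
The setters exercised above are generated by media_sequence_util, which also emits matching get_*, has_*, and clear_* accessors. A short round-trip sketch under that assumption:

example = tf.train.SequenceExample()
ms.set_clip_label_index((47, 49), example)
assert ms.has_clip_label_index(example)
assert list(ms.get_clip_label_index(example)) == [47, 49]
ms.clear_clip_label_index(example)
assert not ms.has_clip_label_index(example)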
Example #3
  def generate_examples(self, path_to_mediapipe_binary,
                        path_to_graph_directory):
    """Downloads data and generates sharded TFRecords.

    Downloads the data files, generates metadata, and processes the metadata
    with MediaPipe to produce tf.SequenceExamples for training. The resulting
    files can be read with as_dataset(). After running this function the
    original data files can be deleted.

    Args:
      path_to_mediapipe_binary: Path to the compiled binary for the BUILD target
        mediapipe/examples/desktop/demo:media_sequence_demo.
      path_to_graph_directory: Path to the directory with MediaPipe graphs in
        mediapipe/graphs/media_sequence/.
    """
    if not path_to_mediapipe_binary:
      raise ValueError("You must supply the path to the MediaPipe binary for "
                       "mediapipe/examples/desktop/demo:media_sequence_demo.")
    if not path_to_graph_directory:
      raise ValueError(
          "You must supply the path to the directory with MediaPipe graphs in "
          "mediapipe/graphs/media_sequence/.")
    logging.info("Downloading data.")
    tf.io.gfile.makedirs(self.path_to_data)
    if sys.version_info >= (3, 0):
      urlretrieve = urllib.request.urlretrieve
    else:
      urlretrieve = urllib.urlretrieve  # Python 2 keeps it on the urllib module
    for split in SPLITS:
      reader = csv.DictReader(SPLITS[split].split("\n"))
      all_metadata = []
      for row in reader:
        url = row["url"]
        basename = url.split("/")[-1]
        local_path = os.path.join(self.path_to_data, basename)
        if not tf.io.gfile.exists(local_path):
          urlretrieve(url, local_path)

        for start_time in range(0, int(row["duration"]), SECONDS_PER_EXAMPLE):
          metadata = tf.train.SequenceExample()
          ms.set_example_id(bytes23(basename + "_" + str(start_time)),
                            metadata)
          ms.set_clip_data_path(bytes23(local_path), metadata)
          ms.set_clip_start_timestamp(start_time * MICROSECONDS_PER_SECOND,
                                      metadata)
          ms.set_clip_end_timestamp(
              (start_time + SECONDS_PER_EXAMPLE) * MICROSECONDS_PER_SECOND,
              metadata)
          ms.set_clip_label_index((int(row["label index"]),), metadata)
          ms.set_clip_label_string((bytes23(row["label string"]),),
                                   metadata)
          all_metadata.append(metadata)
      random.seed(47)
      random.shuffle(all_metadata)
      shard_names = [self._indexed_shard(split, i) for i in range(NUM_SHARDS)]
      writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
      with _close_on_exit(writers) as writers:
        for i, seq_ex in enumerate(all_metadata):
          for graph in GRAPHS:
            graph_path = os.path.join(path_to_graph_directory, graph)
            seq_ex = self._run_mediapipe(path_to_mediapipe_binary, seq_ex,
                                         graph_path)
          writers[i % len(writers)].write(seq_ex.SerializeToString())
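
The _close_on_exit helper used above is defined elsewhere; its use as a context manager that yields the writers implies an implementation along these lines (a minimal sketch):

import contextlib

@contextlib.contextmanager
def _close_on_exit(writers):
  """Yields the writers and closes each one on exit, even on error."""
  try:
    yield writers
  finally:
    for writer in writers:
      writer.close()

A caller would then invoke generate_examples with the two required paths, for example (hypothetical paths):

dataset.generate_examples(
    "/path/to/media_sequence_demo",
    "/path/to/mediapipe/graphs/media_sequence")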