示例#1
0
def test_png_img_writer():
    """Fetching a DataSetObject's bytes populates its attached CacheManager.

    The cache starts with no entry for the mock PNG's tag; after ``get()``
    the cached entry must equal the bytes that ``get()`` returned.
    """
    cache_dir = get_temp_path(prefix="test_png_img_writer")
    cache_manager = CacheManager(cache_dir)

    png_data = ImageMockTaggedData(relative_path="img.png")

    # Precondition: nothing has been cached for this tag yet.
    assert cache_manager.get(png_data.tag) is None

    png_object = DataSetObject(png_data)
    png_object.add_cache_manager(cache_manager)
    # get() loads the object; with a cache manager attached it also lands in the cache.
    loaded_bytes = png_object.get()
    assert cache_manager.get(png_object.tag) == loaded_bytes
示例#2
0
def test_classification_output():
    """``classification_output`` yields the unchanged image paired with the label."""
    expected_img = get_img()

    ds_object = DataSetObject(MockTaggedData("", expected_img),
                              output_function=classification_output)
    ds_object.label = 5

    produced_img, produced_label = ds_object.output()

    assert np.all(produced_img == expected_img)
    assert produced_label == 5
示例#3
0
def test_image():
    """DataSetObject reports the generated image's width and height.

    Checked for both the default ``get_img`` output (rgb) and the
    ``bw=True`` variant, each generated at 300x200.
    """
    # The first pass uses get_img's defaults; the second adds bw=True.
    for extra_kwargs in ({}, {"bw": True}):
        img = get_img(width=300, height=200, **extra_kwargs)
        ds_object = DataSetObject(MockTaggedData("", img))

        assert ds_object.width == 300
        assert ds_object.height == 200
示例#4
0
def detection_collection(
    image_count=15,
    max_annotations_per_image=5,
    img_width=1280,
    img_height=720,
    amount_of_classes=3,
):
    """Build a mock detection dataset: images with random bounding boxes.

    Args:
        image_count: Number of images to create; image ``i`` is named
            ``"{i}.jpg"``.
        max_annotations_per_image: Inclusive upper bound on the number of
            bounding boxes added to each image. Every image gets at least one
            box, so this must be >= 1 (``random.randint`` raises ValueError
            otherwise).
        img_width: Width of each generated image; also passed to
            ``get_random_bounding_box``.
        img_height: Height of each generated image; also passed to
            ``get_random_bounding_box``.
        amount_of_classes: Number of class labels; passed to
            ``get_random_bounding_box`` and used to build the label mapping.

    Returns:
        ``(images, label_mapping)``: the list of DataSetObjects and a dict
        mapping each class index ``0..amount_of_classes-1`` to
        ``"label_{index}"``.
    """
    # Generate the images at the same size the boxes are drawn for, so the
    # annotations refer to coordinates that exist in the image.
    images = [
        DataSetObject(tagged_data=MockTaggedData(
            f"{i}.jpg", get_img(width=img_width, height=img_height)))
        for i in range(image_count)
    ]
    for image in images:
        # random.randint includes both endpoints: 1..max_annotations_per_image.
        annotation_count = random.randint(1, max_annotations_per_image)
        for _ in range(annotation_count):
            image.add_annotation(
                get_random_bounding_box(
                    img_width=img_width,
                    img_height=img_height,
                    amount_of_classes=amount_of_classes,
                ))
    label_mapping = {
        label_index: f"label_{label_index}"
        for label_index in range(amount_of_classes)
    }
    return images, label_mapping
示例#5
0
def mask_objects() -> List[DataSetObject]:
    """Build mock DataSetObjects laid out as ``{subset}/{mask|image}/{index}.jpg``.

    The nesting order is subset ("train", "test", "val"), then image type
    ("mask", "image"), then index 0-9, giving 60 objects in total. Each
    object's mock data is the string ``f"{subset_name}{img_index}"``
    (e.g. ``"train3"``). The mask and image at the same subset and index
    therefore carry the same data.

    Returns:
        The list of created DataSetObjects, in the order described above.
    """
    inputs = []
    for subset_name in ["train", "test", "val"]:
        for img_type in ["mask", "image"]:
            for img_index in range(10):
                tagged_data = MockTaggedData(
                    relative_path=f"{subset_name}/{img_type}/{img_index}.jpg",
                    # Must be an f-string; without the prefix every object
                    # would get the same literal placeholder text.
                    data=f"{subset_name}{img_index}",
                )
                inputs.append(DataSetObject(tagged_data))

    return inputs
示例#6
0
    def transform_single_object(
            self, ds_input: DataSetObject) -> List[DataSetObject]:
        """Label *ds_input* with a value parsed from its relative path.

        ``extract_values`` is applied with ``self.template`` to
        ``ds_input.relative_path``. If the result contains
        ``self.input_item``, that value is assigned to ``ds_input.label``.
        If it does not, the object is dropped (empty list) when
        ``self.drop_if_no_match`` is truthy. Otherwise it is returned
        unlabeled.

        Returns:
            A list containing ``ds_input`` or, when dropped, an empty list.
        """
        extracted = extract_values(template=self.template,
                                   path=ds_input.relative_path)
        if self.input_item not in extracted:
            return [] if self.drop_if_no_match else [ds_input]

        ds_input.label = extracted[self.input_item]
        return [ds_input]
示例#7
0
def test_mask():
    """A Dataset of one object with a Mask annotation yields (image, mask).

    The object uses ``single_mask_output``. Indexing the dataset must
    return the original image and mask arrays.
    """
    expected_img = get_img()
    expected_mask = get_img()

    img_tagged = MockTaggedData("", expected_img)
    mask_tagged = MockTaggedData("", expected_mask)

    ds_object = DataSetObject(img_tagged, output_function=single_mask_output)
    ds_object.annotations.append(Mask(mask_tagged))

    dataset = Dataset()
    dataset.form_from_ds_objects([ds_object])

    produced_img, produced_mask = dataset[0]

    assert np.all(produced_img == expected_img)
    assert np.all(produced_mask == expected_mask)
示例#8
0
    def transform_single_object(
            self, dataset_object: DataSetObject) -> List[DataSetObject]:
        """Attach ``self.cache_manager`` to *dataset_object*.

        Returns:
            A single-element list holding the same object (not a copy).
        """
        manager = self.cache_manager
        dataset_object.add_cache_manager(manager)
        return [dataset_object]
示例#9
0
def triple_output(ds_object: DataSetObject) -> int:
    """Return the object's parsed value multiplied by three.

    Raises AssertionError if ``get_parsed()`` does not return an int. This
    is an ``assert`` statement, so the check is skipped under ``python -O``.
    """
    value = ds_object.get_parsed()
    assert isinstance(value, int)
    tripled = value * 3
    return tripled
示例#10
0
文件: base.py 项目: GiteZz/poif
 def form(self, data: List[TaggedData]):
     """Wrap each TaggedData in a DataSetObject and form the dataset from them.

     Every created object uses ``self.output_function``. The resulting list
     is handed to ``self.form_from_ds_objects``.
     """
     ds_objects = [
         DataSetObject(item, output_function=self.output_function)
         for item in data
     ]
     self.form_from_ds_objects(ds_objects)