Example #1
import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit

# Image and image_copy are rikai helpers; the import paths below assume
# the rikai package layout.
from rikai.spark.functions import image_copy
from rikai.types import Image


def test_image_copy(spark: SparkSession, tmpdir):
    source_image = os.path.join(tmpdir, "source_image")
    with open(source_image, "w") as fobj:
        fobj.write("abc")
    os.makedirs(os.path.join(tmpdir, "out"))

    df = spark.createDataFrame([(Image(source_image), )],
                               ["image"])  # type: pyspark.sql.DataFrame
    df = df.withColumn(
        "image",
        image_copy(col("image"), lit(os.path.join(tmpdir, "out/"))),
    )
    data = df.collect()  # force the lazy computation to run
    out_file = os.path.join(tmpdir, "out", "source_image")
    assert Image(out_file) == data[0].image

    with open(out_file) as fobj:
        assert fobj.read() == "abc"
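
The pytest version above depends on a spark fixture that the snippet does not show (tmpdir is pytest's built-in fixture). Below is a minimal conftest.py sketch that could provide it; the app name and local master are assumptions, not part of the original:

# conftest.py -- hypothetical fixture wiring for the pytest example above.
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark() -> SparkSession:
    # A small local session is enough for these tests.
    session = (
        SparkSession.builder
        .master("local[2]")
        .appName("rikai-tests")
        .getOrCreate()
    )
    yield session
    session.stop()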
Example #2
    def test_image_copy(self):
        source_image = os.path.join(self.test_dir, "source_image")
        with open(source_image, "w") as fobj:
            fobj.write("abc")
        os.makedirs(os.path.join(self.test_dir, "out"))

        df = self.spark.createDataFrame(
            [(Image(source_image), )],
            ["image"])  # type: pyspark.sql.DataFrame
        df = df.withColumn(
            "image",
            image_copy(col("image"), lit(os.path.join(self.test_dir, "out/"))),
        )
        data = df.collect()  # force the lazy computation to run
        out_file = os.path.join(self.test_dir, "out", "source_image")
        self.assertEqual(Image(out_file), data[0].image)

        with open(out_file) as fobj:
            self.assertEqual("abc", fobj.read())
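
The unittest variant assumes a test case that prepares self.spark and self.test_dir before each test. A minimal sketch of such a base class follows; the class name and session settings are assumptions, since the original setup code is not shown:

import shutil
import tempfile
import unittest

from pyspark.sql import SparkSession


class SparkTestCase(unittest.TestCase):
    # Hypothetical base class supplying self.spark and self.test_dir.

    @classmethod
    def setUpClass(cls):
        cls.spark = (
            SparkSession.builder
            .master("local[2]")
            .appName("rikai-tests")
            .getOrCreate()
        )

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir, ignore_errors=True)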
Example #3
import os
from itertools import islice
from typing import Optional

from pycocotools.coco import COCO
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, lit
from pyspark.sql.types import (
    ArrayType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
)

# rikai-specific types and UDFs; the import paths assume the rikai
# package layout.
from rikai.spark.functions import image_copy
from rikai.spark.types import Box2dType, ImageType
from rikai.types import Box2d, Image


def convert(
    spark: SparkSession,
    dataset_root: str,
    limit: int = 0,
    asset_dir: Optional[str] = None,
) -> DataFrame:
    """Convert a Coco Dataset into Rikai dataset.

    This function expects the COCO datasets are stored in directory with the
    following structure:

    - dataset
        - annotations
          - captions_train2017.json
          - instances_train2017.json
          - ...
        - train2017
        - val2017
        - test2017

    Parameters
    ----------
    spark : SparkSession
        A live SparkSession.
    dataset_root : str
        The root directory of the COCO dataset.
    limit : int, optional
        If greater than zero, convert at most this many images per split.
    asset_dir : str, optional
        The asset directory to which images are copied; it can be an S3 URI.

    Returns
    -------
    DataFrame
        A Spark DataFrame of the converted dataset.
    """
    train_json = os.path.join(dataset_root, "annotations",
                              "instances_train2017.json")
    val_json = os.path.join(dataset_root, "annotations",
                            "instances_val2017.json")

    # load_categories is a helper defined elsewhere in this module; it maps
    # a category id to its category record in the annotation file.
    categories = load_categories(train_json)

    examples = []
    for split, anno_file in zip(["train", "val"], [train_json, val_json]):
        coco = COCO(annotation_file=anno_file)
        # pycocotools has native dependencies, so we do not distribute it
        # to the workers; the conversion loop runs on the driver.
        image_ids = coco.imgs  # image id -> image record; iterating yields ids
        if limit > 0:
            image_ids = islice(image_ids, limit)
        for image_id in image_ids:
            ann_id = coco.getAnnIds(imgIds=image_id)
            annotations = coco.loadAnns(ann_id)
            annos = []
            for ann in annotations:
                bbox = Box2d(*ann["bbox"])
                annos.append({
                    "category_id": ann["category_id"],
                    "category_text": categories[ann["category_id"]]["name"],
                    "bbox": bbox,
                    "area": float(ann["area"]),
                })
            image_payload = coco.loadImgs(ids=image_id)[0]
            example = {
                "image_id": image_id,
                "annotations": annos,
                # Image files live under <dataset_root>/<split>2017/.
                "image": Image(
                    os.path.abspath(
                        os.path.join(
                            dataset_root,
                            "{}2017".format(split),
                            image_payload["file_name"],
                        )
                    )
                ),
                "split": split,
            }
            examples.append(example)

    schema = StructType([
        StructField("image_id", LongType(), False),
        StructField(
            "annotations",
            ArrayType(
                StructType([
                    StructField("category_id", IntegerType()),
                    StructField("category_text", StringType()),
                    StructField("area", FloatType()),
                    StructField("bbox", Box2dType()),
                ])),
            False,
        ),
        StructField("image", ImageType(), False),
        StructField("split", StringType(), False),
    ])
    df = spark.createDataFrame(examples, schema=schema)
    if asset_dir:
        asset_dir = asset_dir if asset_dir.endswith("/") else asset_dir + "/"
        print("ASSET DIR: ", asset_dir)
        df = df.withColumn("image", image_copy(col("image"), lit(asset_dir)))
    return df
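
A short usage sketch for convert. The paths are placeholders, and writing through the "rikai" data source assumes the rikai Spark extensions are installed:

# Hypothetical driver code; all paths below are placeholders.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("coco-to-rikai").getOrCreate()
df = convert(
    spark,
    dataset_root="/data/coco",             # laid out as in the docstring
    limit=100,                             # at most 100 images per split
    asset_dir="s3://bucket/coco-assets/",  # images are copied here
)
df.write.format("rikai").save("s3://bucket/datasets/coco")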