Example #1
def copy_fixtures():
    # Upload file to mocked S3 bucket
    s3_client = get_s3_client()

    s3_client.create_bucket(Bucket=BUCKET)
    s3_client.create_bucket(Bucket=DATA_LAKE_BUCKET)
    s3_client.create_bucket(Bucket=TILE_CACHE_BUCKET)

    upload_fake_data(**FAKE_INT_DATA_PARAMS)
    upload_fake_data(**FAKE_FLOAT_DATA_PARAMS)

    s3_client.upload_file(GEOJSON_PATH, BUCKET, GEOJSON_NAME)
    s3_client.upload_file(TSV_PATH, BUCKET, TSV_NAME)
    s3_client.upload_file(SHP_PATH, BUCKET, SHP_NAME)
    s3_client.upload_file(APPEND_TSV_PATH, BUCKET, APPEND_TSV_NAME)

    # upload a separate file for each row so we can test running large
    # numbers of sources in parallel
    with open(TSV_PATH, newline="") as src:
        reader = csv.DictReader(src, delimiter="\t")
        for row in reader:
            out = io.StringIO(newline="")
            writer = csv.DictWriter(out, delimiter="\t", fieldnames=reader.fieldnames)
            writer.writeheader()
            writer.writerow(row)

            s3_client.put_object(
                Body=str.encode(out.getvalue()),
                Bucket=BUCKET,
                Key=f"test_{reader.line_num}.tsv",
            )
            out.close()
Example #2
async def test_delete_raster_tileset_assets():
    s3_client = get_s3_client()
    dataset = "test_delete_raster_tileset"
    version = "table"
    srid = "epsg-4326"
    grid = "10/40000"
    value = "year"

    for i in range(0, 10):
        s3_client.upload_file(
            TSV_PATH,
            DATA_LAKE_BUCKET,
            f"{dataset}/{version}/raster/{srid}/{grid}/{value}/test_{i}.tsv",
        )

    response = s3_client.list_objects_v2(Bucket=DATA_LAKE_BUCKET,
                                         Prefix=dataset)

    assert response["KeyCount"] == 10

    await delete_raster_tileset_assets(dataset, version, srid, grid, value)

    response = s3_client.list_objects_v2(Bucket=DATA_LAKE_BUCKET,
                                         Prefix=dataset)
    assert response["KeyCount"] == 0
Example #3
def check_s3_file_present(bucket, keys):
    s3_client = get_s3_client()

    for key in keys:
        try:
            s3_client.head_object(Bucket=bucket, Key=key)
        except ClientError:
            raise AssertionError(
                f"Object {key} doesn't exist in bucket {bucket}!")
Example #4
def test_delete_s3_objects():
    """" Make sure we can delete more than 1000 items."""

    s3_client = get_s3_client()

    for i in range(1001):
        s3_client.upload_file(TSV_PATH, BUCKET, "TEST_DELETE_S3_OBJECTS" + str(i))

    count = delete_s3_objects(BUCKET, "TEST_DELETE_S3_OBJECTS")
    assert count == 1001
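delete_s3_objects itself is not shown in this listing. A minimal sketch of how such a helper could satisfy the 1001-object assertion, assuming boto3's list_objects_v2 paginator and batched delete_objects calls (the real implementation may differ):

# Hypothetical sketch -- the real delete_s3_objects may differ.
from app.utils.aws import get_s3_client


def delete_s3_objects(bucket: str, prefix: str) -> int:
    """Delete all objects under a prefix and return how many were removed."""
    s3_client = get_s3_client()
    paginator = s3_client.get_paginator("list_objects_v2")

    count = 0
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
        if not keys:
            continue
        # Each page holds at most 1,000 keys, which is also the limit that
        # delete_objects accepts, so deleting page by page stays within the API limits.
        s3_client.delete_objects(Bucket=bucket, Delete={"Objects": keys})
        count += len(keys)
    return count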
Example #5
async def get_extent(asset_id: UUID) -> Optional[Extent]:
    asset_row: ORMAsset = await get_asset(asset_id)
    asset_uri: str = get_asset_uri(
        asset_row.dataset,
        asset_row.version,
        asset_row.asset_type,
        asset_row.creation_options,
        srid=infer_srid_from_grid(asset_row.creation_options.get("grid")),
    )
    bucket, key = split_s3_path(tile_uri_to_extent_geojson(asset_uri))

    s3_client = get_s3_client()
    resp = s3_client.get_object(Bucket=bucket, Key=key)
    extent_geojson: Dict[str, Any] = json.loads(resp["Body"].read().decode("utf-8"))

    if extent_geojson:
        return Extent(**extent_geojson)
    return None
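split_s3_path is used by several of these examples but not shown. A minimal sketch, assuming it only needs to split an s3:// URI into bucket and key (the real helper may differ):

# Hypothetical sketch of the helper used above.
from typing import Tuple
from urllib.parse import urlparse


def split_s3_path(s3_uri: str) -> Tuple[str, str]:
    """Split "s3://bucket/path/to/key" into ("bucket", "path/to/key")."""
    parsed = urlparse(s3_uri)
    return parsed.netloc, parsed.path.lstrip("/")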
Example #6
def is_zipped(s3_uri: str) -> bool:
    """Get basename of source file.

    If Zipfile, add VSIZIP prefix for GDAL
    """
    bucket, key = split_s3_path(s3_uri)
    client = get_s3_client()
    _, ext = os.path.splitext(s3_uri)

    try:
        header = client.head_object(Bucket=bucket, Key=key)
        # TODO: moto does not return the correct ContentType, so also check the extension
        if header["ContentType"] == "application/x-zip-compressed" or ext == ".zip":
            return True
    except (KeyError, ClientError):
        raise FileNotFoundError(f"Cannot access source file {s3_uri}")

    return False
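A hypothetical usage sketch showing how a caller might act on the result, following the docstring's note about GDAL's /vsizip/ prefix (get_vsi_path is illustrative and not taken from the application code):

def get_vsi_path(s3_uri: str) -> str:
    # Hypothetical helper: build a GDAL-readable path for an S3 source.
    bucket, key = split_s3_path(s3_uri)
    vsi_path = f"/vsis3/{bucket}/{key}"
    if is_zipped(s3_uri):
        # GDAL reads zipped sources through the chained /vsizip//vsis3/... prefix.
        vsi_path = f"/vsizip/{vsi_path}"
    return vsi_path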
Example #7
async def _get_raster_stats(asset_id: UUID) -> RasterStats:
    asset_row: ORMAsset = await get_asset(asset_id)

    asset_uri: str = get_asset_uri(
        asset_row.dataset,
        asset_row.version,
        asset_row.asset_type,
        asset_row.creation_options,
        srid=infer_srid_from_grid(asset_row.creation_options.get("grid")),
    )
    bucket, tiles_key = split_s3_path(tile_uri_to_tiles_geojson(asset_uri))

    s3_client = get_s3_client()
    tiles_resp = s3_client.get_object(Bucket=bucket, Key=tiles_key)
    tiles_geojson: Dict[str, Any] = json.loads(
        tiles_resp["Body"].read().decode("utf-8"))

    bandstats: List[BandStats] = _collect_bandstats(
        FeatureCollection(**tiles_geojson))

    return RasterStats(bands=bandstats)
Example #8
def upload_fake_data(dtype, dtype_name, no_data, prefix, data):
    s3_client = get_s3_client()

    data_file_name = "0000000000-0000000000.tif"

    tiles_geojson = {
        "type": "FeatureCollection",
        "features": [{
            "type": "Feature",
            "geometry": {
                "type": "Polygon",
                "coordinates": [[
                    [10.0, 10.0],
                    [12.0, 11.0],
                    [12.0, 10.0],
                    [11.0, 10.0],
                    [11.0, 11.0],
                ]],
            },
            "properties": {
                "name": f"/vsis3/{DATA_LAKE_BUCKET}/{prefix}/{data_file_name}"
            },
        }],
    }

    dataset_profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "nodata": no_data,
        "count": 1,
        "width": 300,
        "height": 300,
        # "blockxsize": 100,
        # "blockysize": 100,
        "crs": CRS.from_epsg(4326),
        # 0.003332345971563981 is the pixel size of 90/27008
        "transform": Affine(0.003332345971563981, 0, 10, 0, -0.003332345971563981, 10),
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        full_tiles_path = os.path.join(tmpdir, "tiles.geojson")

        with open(full_tiles_path, "w") as dst:
            dst.write(json.dumps(tiles_geojson))
        s3_client.upload_file(
            full_tiles_path,
            DATA_LAKE_BUCKET,
            f"{prefix}/tiles.geojson",
        )

        full_data_file_path = os.path.join(tmpdir, data_file_name)
        with rasterio.Env():
            with rasterio.open(full_data_file_path, "w", **dataset_profile) as dst:
                dst.write(data.astype(dtype), 1)
        s3_client.upload_file(
            full_data_file_path,
            DATA_LAKE_BUCKET,
            f"{prefix}/{data_file_name}",
        )
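The FAKE_INT_DATA_PARAMS / FAKE_FLOAT_DATA_PARAMS fixtures referenced in Examples #1 and #10 expand into the keyword arguments above. The concrete values below are purely illustrative; the real fixtures in tests/conftest.py may differ:

import numpy as np

# Hypothetical fixture values matching the upload_fake_data() signature.
FAKE_INT_DATA_PARAMS = {
    "dtype": "uint8",        # rasterio dtype used to write the 300x300 GeoTIFF
    "dtype_name": "uint8",   # reported to the API as "data_type"
    "no_data": 0,
    "prefix": "test/v1.0.0/raw",
    "data": np.random.randint(1, 10, size=(300, 300)),
}

upload_fake_data(**FAKE_INT_DATA_PARAMS)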
Example #9
def delete_s3_files(bucket, prefix):
    s3_client = get_s3_client()
    # Note: list_objects_v2 returns at most 1,000 keys, so this helper only
    # clears the first page of results.
    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
    for obj in response.get("Contents", list()):
        print("Deleting", obj["Key"])
        s3_client.delete_object(Bucket=bucket, Key=obj["Key"])
Example #10
async def test_version_put_raster(async_client):
    """Test raster source version operations."""

    dataset = "test_version_put_raster"
    version = "v1.0.0"

    s3_client = get_s3_client()

    pixetl_output_files = [
        f"{dataset}/{version}/raster/epsg-4326/90/27008/percent/gdal-geotiff/extent.geojson",
        f"{dataset}/{version}/raster/epsg-4326/90/27008/percent/geotiff/extent.geojson",
        f"{dataset}/{version}/raster/epsg-4326/90/27008/percent/gdal-geotiff/tiles.geojson",
        f"{dataset}/{version}/raster/epsg-4326/90/27008/percent/geotiff/tiles.geojson",
        f"{dataset}/{version}/raster/epsg-4326/90/27008/percent/gdal-geotiff/90N_000E.tif",
        f"{dataset}/{version}/raster/epsg-4326/90/27008/percent/geotiff/90N_000E.tif",
    ]

    for key in pixetl_output_files:
        s3_client.delete_object(Bucket="gfw-data-lake-test", Key=key)

    raster_version_payload = {
        "creation_options": {
            "source_type": "raster",
            "source_uri": [
                f"s3://{DATA_LAKE_BUCKET}/{FAKE_INT_DATA_PARAMS['prefix']}/tiles.geojson"
            ],
            "source_driver": "GeoTIFF",
            "data_type": FAKE_INT_DATA_PARAMS["dtype_name"],
            "no_data": FAKE_INT_DATA_PARAMS["no_data"],
            "pixel_meaning": "percent",
            "grid": "90/27008",
            "resampling": "nearest",
            "overwrite": True,
            "subset": "90N_000E",
        },
        "metadata": payload["metadata"],
    }

    await create_default_asset(
        dataset,
        version,
        version_payload=raster_version_payload,
        async_client=async_client,
        execute_batch_jobs=True,
    )

    for key in pixetl_output_files:
        try:
            s3_client.head_object(Bucket="gfw-data-lake-test", Key=key)
        except ClientError:
            raise AssertionError(f"Key {key} doesn't exist!")

    # test downloading the asset
    response = await async_client.get(
        f"/dataset/{dataset}/{version}/download/geotiff",
        params={
            "grid": "90/27008",
            "tile_id": "90N_000E",
            "pixel_meaning": "percent"
        },
        allow_redirects=False,
    )
    assert response.status_code == 307
    url = urlparse(response.headers["Location"])
    assert url.scheme == "http"
    assert url.netloc == urlparse(S3_ENTRYPOINT_URL).netloc
    assert (
        url.path ==
        f"/gfw-data-lake-test/{dataset}/{version}/raster/epsg-4326/90/27008/percent/geotiff/90N_000E.tif"
    )
    assert "AWSAccessKeyId" in url.query
    assert "Signature" in url.query
    assert "Expires" in url.query

    response = await async_client.get(
        f"/dataset/{dataset}/{version}/download/geotiff",
        params={
            "grid": "10/40000",
            "tile_id": "90N_000E",
            "pixel_meaning": "percent"
        },
        allow_redirects=False,
    )
    assert response.status_code == 404
Example #11
from app.utils.aws import get_s3_client
from tests import BUCKET, DATA_LAKE_BUCKET, SHP_NAME
from tests.conftest import FAKE_FLOAT_DATA_PARAMS, FAKE_INT_DATA_PARAMS
from tests.tasks import MockCloudfrontClient
from tests.utils import (
    check_s3_file_present,
    check_tasks_status,
    create_dataset,
    create_default_asset,
    create_version,
    delete_s3_files,
    generate_uuid,
    poll_jobs,
)

s3_client = get_s3_client()


@pytest.mark.asyncio
async def test_assets(async_client):
    """Basic tests of asset endpoint behavior."""
    # Add a dataset, version, and default asset
    dataset = "test_assets"
    version = "v20200626"

    asset = await create_default_asset(dataset,
                                       version,
                                       async_client=async_client,
                                       execute_batch_jobs=False)
    asset_id = asset["asset_id"]
Example #12
async def test_get_aws_files():
    good_bucket = "good_bucket"
    good_prefix = "good_prefix"

    s3_client = get_s3_client()

    s3_client.create_bucket(Bucket=good_bucket)
    s3_client.put_object(
        Bucket=good_bucket, Key=f"{good_prefix}/world.tif", Body="booga booga!"
    )

    # Import this inside the test function so we're covered
    # by the mock_s3 decorator
    from app.utils.aws import get_aws_files

    keys = get_aws_files(good_bucket, good_prefix)
    assert len(keys) == 1
    assert keys[0] == f"/vsis3/{good_bucket}/{good_prefix}/world.tif"

    keys = get_aws_files(good_bucket, good_prefix, extensions=[".pdf"])
    assert len(keys) == 0

    keys = get_aws_files(good_bucket, "bad_prefix")
    assert len(keys) == 0

    keys = get_aws_files("bad_bucket", "doesnt_matter")
    assert len(keys) == 0

    s3_client.put_object(
        Bucket=good_bucket, Key=f"{good_prefix}/another_world.csv", Body="booga booga!"
    )

    keys = get_aws_files(good_bucket, good_prefix)
    assert len(keys) == 2
    assert f"/vsis3/{good_bucket}/{good_prefix}/another_world.csv" in keys
    assert f"/vsis3/{good_bucket}/{good_prefix}/world.tif" in keys

    keys = get_aws_files(good_bucket, good_prefix, extensions=[".csv"])
    assert len(keys) == 1
    assert keys[0] == f"/vsis3/{good_bucket}/{good_prefix}/another_world.csv"

    keys = get_aws_files(good_bucket, good_prefix, limit=1)
    assert len(keys) == 1
    assert (
        f"/vsis3/{good_bucket}/{good_prefix}/another_world.csv" in keys
        or f"/vsis3/{good_bucket}/{good_prefix}/world.tif" in keys
    )

    s3_client.put_object(
        Bucket=good_bucket, Key=f"{good_prefix}/coverage_layer.tif", Body="booga booga!"
    )
    keys = get_aws_files(good_bucket, good_prefix)
    assert len(keys) == 3
    assert f"/vsis3/{good_bucket}/{good_prefix}/another_world.csv" in keys
    assert f"/vsis3/{good_bucket}/{good_prefix}/coverage_layer.tif" in keys
    assert f"/vsis3/{good_bucket}/{good_prefix}/world.tif" in keys

    keys = get_aws_files(
        good_bucket, good_prefix, exit_after_max=1, extensions=[".tif"]
    )
    assert len(keys) == 1
    assert (
        f"/vsis3/{good_bucket}/{good_prefix}/coverage_layer.tif" in keys
        or f"/vsis3/{good_bucket}/{good_prefix}/world.tif" in keys
    )
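get_aws_files is imported from app.utils.aws and not shown here. A minimal sketch reconstructed from the assertions above; the handling of limit vs. exit_after_max is a guess and may not match the real helper:

# Hypothetical sketch -- reconstructed from the test, not the actual implementation.
from typing import List, Optional, Sequence

from botocore.exceptions import ClientError

from app.utils.aws import get_s3_client


def get_aws_files(
    bucket: str,
    prefix: str,
    extensions: Optional[Sequence[str]] = None,
    limit: Optional[int] = None,
    exit_after_max: Optional[int] = None,
) -> List[str]:
    """List objects under a prefix and return them as GDAL /vsis3/ paths."""
    s3_client = get_s3_client()
    matches: List[str] = []
    try:
        paginator = s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if extensions and not any(key.endswith(ext) for ext in extensions):
                    continue
                matches.append(f"/vsis3/{bucket}/{key}")
                # Both caps are treated identically in this sketch; the real helper
                # may count listed keys and matching keys differently.
                if (limit and len(matches) >= limit) or (
                    exit_after_max and len(matches) >= exit_after_max
                ):
                    return matches
    except ClientError:
        # A nonexistent bucket yields an empty list rather than an exception.
        return []
    return matches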
Example #13
async def test_vector_tile_asset(
    mocked_cloudfront_client, ecs_client, batch_client, async_client
):
    _, logs = batch_client
    ecs_client.return_value = MockECSClient()
    ############################
    # Setup test
    ############################

    dataset = "test"
    source = SHP_NAME

    version = "v1.1.1"
    input_data = {
        "creation_options": {
            "source_type": "vector",
            "source_uri": [f"s3://{BUCKET}/{source}"],
            "source_driver": "GeoJSON",
            "create_dynamic_vector_tile_cache": True,
        },
        "metadata": {},
    }

    await create_default_asset(
        dataset,
        version,
        version_payload=input_data,
        async_client=async_client,
        logs=logs,
        execute_batch_jobs=True,
        skip_dataset=False,
    )

    ### Create static tile cache asset
    httpx.delete(f"http://localhost:{PORT}")

    input_data = {
        "asset_type": "Static vector tile cache",
        "is_managed": True,
        "creation_options": {
            "min_zoom": 0,
            "max_zoom": 9,
            "tile_strategy": "discontinuous",
            "layer_style": [
                {
                    "id": dataset,
                    "paint": {"fill-color": "#9c9c9c", "fill-opacity": 0.8},
                    "source-layer": dataset,
                    "source": dataset,
                    "type": "fill",
                }
            ],
        },
    }

    response = await async_client.post(
        f"/dataset/{dataset}/{version}/assets", json=input_data
    )
    print(response.json())
    assert response.status_code == 202
    asset_id = response.json()["data"]["asset_id"]

    # get task ids from change log and wait until finished
    response = await async_client.get(f"/asset/{asset_id}/change_log")

    assert response.status_code == 200
    tasks = json.loads(response.json()["data"][-1]["detail"])
    task_ids = [task["job_id"] for task in tasks]
    print(task_ids)

    # make sure all jobs completed
    status = await poll_jobs(task_ids, logs=logs, async_client=async_client)
    assert status == "saved"

    response = await async_client.get(f"/dataset/{dataset}/{version}/assets")
    assert response.status_code == 200

    # there should be 4 assets now (geodatabase table, dynamic vector tile cache, ndjson and static vector tile cache)
    assert len(response.json()["data"]) == 4

    # there should be 10 files on s3 including the root.json and VectorTileServer files
    s3_client = get_s3_client()
    resp = s3_client.list_objects_v2(
        Bucket=TILE_CACHE_BUCKET, Prefix=f"{dataset}/{version}/default/"
    )
    print(resp)
    assert resp["KeyCount"] == 10

    response = await async_client.get(
        f"/dataset/{dataset}/{version}/assets?asset_type=ndjson"
    )
    assert response.status_code == 200
    assert len(response.json()["data"]) == 1
    asset_id = response.json()["data"][0]["asset_id"]

    # Check if file is in data lake
    resp = s3_client.list_objects_v2(
        Bucket=DATA_LAKE_BUCKET, Prefix=f"{dataset}/{version}/vector/"
    )
    print(resp)
    assert resp["KeyCount"] == 1

    response = await async_client.delete(f"/asset/{asset_id}")
    assert response.status_code == 200

    # Check if file was deleted
    resp = s3_client.list_objects_v2(
        Bucket=DATA_LAKE_BUCKET, Prefix=f"{dataset}/{version}/vector/"
    )
    print(resp)
    assert resp["KeyCount"] == 0

    ###########
    # 1x1 Grid
    ###########
    ### Create 1x1 grid asset
    httpx.delete(f"http://localhost:{PORT}")

    input_data = {
        "asset_type": "1x1 grid",
        "is_managed": True,
        "creation_options": {},
    }

    response = await async_client.post(
        f"/dataset/{dataset}/{version}/assets", json=input_data
    )

    print(response.json())
    assert response.status_code == 202
    asset_id = response.json()["data"]["asset_id"]

    # get task ids from change log and wait until finished
    response = await async_client.get(f"/asset/{asset_id}/change_log")

    assert response.status_code == 200
    tasks = json.loads(response.json()["data"][-1]["detail"])
    task_ids = [task["job_id"] for task in tasks]
    print(task_ids)

    # make sure all jobs completed
    status = await poll_jobs(task_ids, logs=logs, async_client=async_client)
    assert status == "saved"

    response = await async_client.get(f"/dataset/{dataset}/{version}/assets")
    assert response.status_code == 200

    # there should be 4 assets now (geodatabase table, dynamic vector tile cache,
    # static vector tile cache, and 1x1 grid; the ndjson asset was already deleted)
    assert len(response.json()["data"]) == 4

    # Check if file is in data lake
    resp = s3_client.list_objects_v2(
        Bucket=DATA_LAKE_BUCKET, Prefix=f"{dataset}/{version}/vector/"
    )
    print(resp)
    assert resp["KeyCount"] == 1

    response = await async_client.delete(f"/asset/{asset_id}")
    print(response.json())
    assert response.status_code == 200

    # Check if file was deleted
    resp = s3_client.list_objects_v2(
        Bucket=DATA_LAKE_BUCKET, Prefix=f"{dataset}/{version}/vector/"
    )
    print(resp)
    assert resp["KeyCount"] == 0

    ###########
    # Vector file export
    ###########

    asset_types = [AssetType.shapefile, AssetType.geopackage]
    for asset_type in asset_types:
        response = await async_client.get(f"/dataset/{dataset}/{version}/assets")
        current_asset_count = len(response.json()["data"])

        httpx.delete(f"http://localhost:{PORT}")

        input_data = {
            "asset_type": asset_type,
            "is_managed": True,
            "creation_options": {},
        }

        response = await async_client.post(
            f"/dataset/{dataset}/{version}/assets", json=input_data
        )

        print(response.json())
        assert response.status_code == 202
        asset_id = response.json()["data"]["asset_id"]

        await check_tasks_status(async_client, logs, [asset_id])

        response = await async_client.get(f"/dataset/{dataset}/{version}/assets")
        assert response.status_code == 200

        # there should be one more asset than before this test
        assert len(response.json()["data"]) == current_asset_count + 1

        # Check if file is in data lake
        resp = s3_client.list_objects_v2(
            Bucket=DATA_LAKE_BUCKET, Prefix=f"{dataset}/{version}/vector/"
        )
        assert resp["KeyCount"] == 1

        # test downloading the asset
        fmt = "shp" if asset_type == AssetType.shapefile else "gpkg"
        ext = "shp.zip" if asset_type == AssetType.shapefile else "gpkg"
        response = await async_client.get(
            f"/dataset/{dataset}/{version}/download/{fmt}", allow_redirects=False
        )
        assert response.status_code == 307
        url = urlparse(response.headers["Location"])
        assert url.scheme == "http"
        assert url.netloc == urlparse(S3_ENTRYPOINT_URL).netloc
        assert (
            url.path
            == f"/gfw-data-lake-test/{dataset}/{version}/vector/epsg-4326/{dataset}_{version}.{ext}"
        )
        assert "AWSAccessKeyId" in url.query
        assert "Signature" in url.query
        assert "Expires" in url.query

        response = await async_client.delete(f"/asset/{asset_id}")
        print(response.json())
        assert response.status_code == 200

        # Check if file was deleted
        resp = s3_client.list_objects_v2(
            Bucket=DATA_LAKE_BUCKET, Prefix=f"{dataset}/{version}/vector/"
        )
        print(resp)
        assert resp["KeyCount"] == 0

    mocked_cloudfront_client.return_value = MockCloudfrontClient()
    response = await async_client.patch(
        f"/dataset/{dataset}/{version}", json={"is_latest": True}
    )
    assert response.status_code == 200
    assert mocked_cloudfront_client.called
Example #14
async def test_table_source_asset_parallel(batch_client, async_client):
    _, logs = batch_client

    ############################
    # Setup test
    ############################

    dataset = "table_test"
    version = "v202002.1"

    s3_client = get_s3_client()

    for i in range(2, 101):
        s3_client.upload_file(TSV_PATH, BUCKET, f"test_{i}.tsv")

    # define partition schema
    partition_schema = list()
    years = range(2018, 2021)
    for year in years:
        for week in range(1, 54):
            try:
                name = f"y{year}_w{week:02}"
                start = pendulum.parse(f"{year}-W{week:02}").to_date_string()
                end = pendulum.parse(f"{year}-W{week:02}").add(
                    days=7).to_date_string()
                partition_schema.append({
                    "partition_suffix": name,
                    "start_value": start,
                    "end_value": end
                })

            except ParserError:
                # Year has only 52 weeks
                pass

    input_data = {
        "creation_options": {
            "source_type": "table",
            "source_uri": [f"s3://{BUCKET}/{TSV_NAME}"]
            + [f"s3://{BUCKET}/test_{i}.tsv" for i in range(2, 101)],
            "source_driver": "text",
            "delimiter": "\t",
            "has_header": True,
            "latitude": "latitude",
            "longitude": "longitude",
            "cluster": {"index_type": "gist", "column_names": ["geom_wm"]},
            "partitions": {
                "partition_type": "range",
                "partition_column": "alert__date",
                "partition_schema": partition_schema,
            },
            "indices": [
                {"index_type": "gist", "column_names": ["geom"]},
                {"index_type": "gist", "column_names": ["geom_wm"]},
                {"index_type": "btree", "column_names": ["alert__date"]},
            ],
            "table_schema": [
                {
                    "field_name": "rspo_oil_palm__certification_status",
                    "field_type": "text",
                },
                {"field_name": "per_forest_concession__type", "field_type": "text"},
                {"field_name": "idn_forest_area__type", "field_type": "text"},
                {"field_name": "alert__count", "field_type": "integer"},
                {"field_name": "adm1", "field_type": "integer"},
                {"field_name": "adm2", "field_type": "integer"},
            ],
        },
        "metadata": {},
    }

    #####################
    # Test asset creation
    #####################

    asset = await create_default_asset(
        dataset,
        version,
        version_payload=input_data,
        execute_batch_jobs=True,
        logs=logs,
        async_client=async_client,
    )
    asset_id = asset["asset_id"]

    await check_version_status(dataset, version, 3)
    await check_asset_status(dataset, version, 1)
    await check_task_status(asset_id, 26, "cluster_partitions_3")

    # There should be a table called "table_test"."v202002.1" with 9900 rows
    # (100 source files with 99 rows each).
    # It should have the right number of partitions and indices.
    async with ContextEngine("READ"):
        count = await db.scalar(
            db.text(f"""
                    SELECT count(*)
                        FROM "{dataset}"."{version}";"""))
        partition_count = await db.scalar(
            db.text(f"""
                    SELECT count(i.inhrelid::regclass)
                        FROM pg_inherits i
                        WHERE  i.inhparent = '"{dataset}"."{version}"'::regclass;"""
                    ))
        index_count = await db.scalar(
            db.text(f"""
                    SELECT count(indexname)
                        FROM pg_indexes
                        WHERE schemaname = '{dataset}' AND tablename like '{version}%';"""
                    ))
        cluster_count = await db.scalar(
            db.text("""
                    SELECT count(relname)
                        FROM   pg_class c
                        JOIN   pg_index i ON i.indrelid = c.oid
                        WHERE  relkind = 'r' AND relhasindex AND i.indisclustered"""
                    ))

    assert count == 9900
    assert partition_count == len(partition_schema)
    # Postgres 12 also adds indices to the main table, hence there are more indices than partitions
    assert index_count == (partition_count + 1) * len(
        input_data["creation_options"]["indices"])
    assert cluster_count == len(partition_schema)
Example #15
async def test_table_source_asset(batch_client, httpd):
    _, logs = batch_client
    httpd_port = httpd.server_port

    ############################
    # Setup test
    ############################

    s3_client = get_s3_client()

    s3_client.create_bucket(Bucket=BUCKET)
    s3_client.upload_file(TSV_PATH, BUCKET, TSV_NAME)

    dataset = "table_test"
    version = "v202002.1"

    # define partition schema
    partition_schema = list()
    years = range(2018, 2021)
    for year in years:
        for week in range(1, 54):
            try:
                name = f"y{year}_w{week:02}"
                start = pendulum.parse(f"{year}-W{week:02}").to_date_string()
                end = pendulum.parse(f"{year}-W{week:02}").add(
                    days=7).to_date_string()
                partition_schema.append({
                    "partition_suffix": name,
                    "start_value": start,
                    "end_value": end
                })

            except ParserError:
                # Year has only 52 weeks
                pass

    input_data = {
        "source_type": "table",
        "source_uri": [f"s3://{BUCKET}/{TSV_NAME}"],
        "is_mutable": True,
        "creation_options": {
            "src_driver": "text",
            "delimiter": "\t",
            "has_header": True,
            "latitude": "latitude",
            "longitude": "longitude",
            "cluster": {"index_type": "gist", "column_name": "geom_wm"},
            "partitions": {
                "partition_type": "range",
                "partition_column": "alert__date",
                "partition_schema": partition_schema,
            },
            "indices": [
                {"index_type": "gist", "column_name": "geom"},
                {"index_type": "gist", "column_name": "geom_wm"},
                {"index_type": "btree", "column_name": "alert__date"},
            ],
            "table_schema": [
                {
                    "field_name": "rspo_oil_palm__certification_status",
                    "field_type": "text",
                },
                {"field_name": "per_forest_concession__type", "field_type": "text"},
                {"field_name": "idn_forest_area__type", "field_type": "text"},
                {"field_name": "alert__count", "field_type": "integer"},
                {"field_name": "adm1", "field_type": "integer"},
                {"field_name": "adm2", "field_type": "integer"},
            ],
        },
        "metadata": {},
    }

    await create_dataset(dataset)
    await create_version(dataset, version, input_data)

    #####################
    # Test asset creation
    #####################

    # Create default asset in mocked BATCH
    async with ContextEngine("WRITE"):
        asset_id = await create_default_asset(
            dataset,
            version,
            input_data,
            None,
        )

    tasks_rows = await tasks.get_tasks(asset_id)
    task_ids = [str(task.task_id) for task in tasks_rows]

    # make sure all jobs completed
    status = await poll_jobs(task_ids)

    # Get the logs in case something went wrong
    _print_logs(logs)
    check_callbacks(task_ids, httpd_port)

    assert status == "saved"

    await _check_version_status(dataset, version)
    await _check_asset_status(dataset, version, 1)
    await _check_task_status(asset_id, 14, "cluster_partitions_3")

    # There should be a table called "table_test"."v202002.1" with 99 rows.
    # It should have the right number of partitions and indices.
    async with ContextEngine("READ"):
        count = await db.scalar(
            db.text(f"""
                    SELECT count(*)
                        FROM "{dataset}"."{version}";"""))
        partition_count = await db.scalar(
            db.text(f"""
                    SELECT count(i.inhrelid::regclass)
                        FROM pg_inherits i
                        WHERE  i.inhparent = '"{dataset}"."{version}"'::regclass;"""
                    ))
        index_count = await db.scalar(
            db.text(f"""
                    SELECT count(indexname)
                        FROM pg_indexes
                        WHERE schemaname = '{dataset}' AND tablename like '{version}%';"""
                    ))
        cluster_count = await db.scalar(
            db.text("""
                    SELECT count(relname)
                        FROM   pg_class c
                        JOIN   pg_index i ON i.indrelid = c.oid
                        WHERE  relkind = 'r' AND relhasindex AND i.indisclustered"""
                    ))

    assert count == 99
    assert partition_count == len(partition_schema)
    assert index_count == partition_count * len(
        input_data["creation_options"]["indices"])
    assert cluster_count == len(partition_schema)

    append_data = {
        "source_uri": [f"s3://{BUCKET}/{APPEND_TSV_NAME}"],
        "source_type": "table",
        "creation_options": input_data["creation_options"]
    }

    # Create default asset in mocked BATCH
    async with ContextEngine("WRITE"):
        await append_default_asset(dataset, version, append_data, asset_id)

    tasks_rows = await tasks.get_tasks(asset_id)
    task_ids = [str(task.task_id) for task in tasks_rows]
    print(task_ids)

    # make sure all jobs completed
    status = await poll_jobs(task_ids)

    # Get the logs in case something went wrong
    _print_logs(logs)
    check_callbacks(task_ids, httpd_port)

    assert status == "saved"

    await _check_version_status(dataset, version, 2)
    await _check_asset_status(dataset, version, 2)

    # The table should now have 101 rows after append
    async with ContextEngine("READ"):
        count = await db.scalar(
            db.text(f"""
                    SELECT count(*)
                        FROM "{dataset}"."{version}";"""))

    assert count == 101