Example #1
 def test_partitioned_totals_approx_equal(self):
     # hand-crafted and fairly loose; should be replaced by something more rigorous
     splits = split_df(self.data_fake, 10)
     totals = [
         self.data_fake[self.data_fake.tile.isin(s)].frac.sum()
         for s in splits
     ]
     assert [45] * len(totals) == pytest.approx(totals, abs=5)
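The comment above flags this check as loose. A tighter variant (hypothetical: it assumes the partition totals only need to be close to each other, not to a hard-coded value) could compare each total against their mean:

 def test_partitioned_totals_close_to_mean(self):
     # hypothetical stricter check: compare each partition total against the
     # mean of all totals instead of a hard-coded 45
     splits = split_df(self.data_fake, 10)
     totals = [
         self.data_fake[self.data_fake.tile.isin(s)].frac.sum() for s in splits
     ]
     mean_total = sum(totals) / len(totals)
     assert totals == pytest.approx([mean_total] * len(totals), abs=5)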
Example #2
 def test_number_of_partitions_as_requested(self):
     result = split_df(self.data, 3)
     assert len(result) == math.ceil(len(self.data) / 3)
Example #3
 def test_total_size_unchanged(self):
     result = split_df(self.data, 3)
     assert len(reduce(lambda z, y: z + y, result)) == len(self.data)
Example #4
 def test_catch_tiles_without_deadtrees(self):
     with pytest.raises(ValueError):
         split_df(self.data_bad, 3)
Example #5
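 # `n` is presumably supplied via a parametrize decorator in the original test
 # class, e.g. @pytest.mark.parametrize("n", [0, -1]) (values hypothetical)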
 def test_catch_invalid_size(self, n):
     with pytest.raises(ValueError):
         split_df(self.data, n)
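Taken together, Examples #1-#5 pin down the contract of split_df: it raises ValueError for an invalid partition size or for tiles without dead trees, returns math.ceil(len(df) / n) partitions of tile names, preserves the total number of rows, and keeps the per-partition frac totals roughly balanced. A minimal sketch consistent with these tests follows; the actual implementation may differ, e.g. in how it balances frac:

import math
from typing import List

import numpy as np
import pandas as pd


def split_df(df: pd.DataFrame, n: int) -> List[List[str]]:
    """Partition the tiles of df into groups of at most n entries,
    roughly balanced by their frac column (illustrative sketch only)."""
    if n <= 0:
        raise ValueError(f"invalid partition size: {n}")
    if (df.frac <= 0).any():
        raise ValueError("found tiles without dead trees")

    # greedy balancing: hand each tile (largest frac first) to the currently
    # lightest partition that still has room
    n_parts = math.ceil(len(df) / n)
    parts: List[List[str]] = [[] for _ in range(n_parts)]
    totals = np.zeros(n_parts)
    for tile, frac in (
        df.sort_values("frac", ascending=False)[["tile", "frac"]].itertuples(index=False)
    ):
        open_parts = [i for i in range(n_parts) if len(parts[i]) < n]
        target = min(open_parts, key=lambda i: totals[i])
        parts[target].append(tile)
        totals[target] += frac
    return parts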
Example #6
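# NOTE: this excerpt relies on the enclosing module's imports and constants,
# which are not shown here: argparse, Path (pathlib), psutil, tempfile, tarfile,
# random, numpy as np, pandas as pd, webdataset as wds, plus SHARDSIZE,
# OVERSAMPLE_FACTOR, split_df, and split_tiles.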
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("image_dir", type=Path)
    parser.add_argument("mask_dir", type=Path)
    parser.add_argument("lu_dir", type=Path)
    parser.add_argument("outdir", type=Path)

    num_cores = psutil.cpu_count(logical=False)
    parser.add_argument(
        "--workers",
        dest="workers",
        type=int,
        default=num_cores,
        help="number of workers for parallel execution [def: %(default)s]",
    )

    parser.add_argument(
        "--source_dim",
        dest="source_dim",
        type=int,
        default=2048,
        help="size of input tiles [def: %(default)s]",
    )

    parser.add_argument(
        "--tile_size",
        dest="tile_size",
        type=int,
        default=256,
        help="size of final tiles that are then passed to the model [def: %(default)s]",
    )

    parser.add_argument(
        "--format",
        dest="format",
        type=str,
        default="TIFF",
        choices=["PNG", "TIFF"],
        help="target file format (PNG, TIFF) [def: %(default)s]",
    )

    parser.add_argument(
        "--tmp-dir",
        dest="tmp_dir",
        type=Path,
        default=None,
        help="use this location as tmp dir",
    )

    parser.add_argument(
        "--subdir",
        dest="sub_dir",
        default="train",
        help="use this location as sub_dir",
    )

    parser.add_argument(
        "--stats",
        dest="stats_file",
        type=Path,
        default=Path("stats.csv"),
        help="use this file to record stats",
    )

    args = parser.parse_args()

    args.outdir.mkdir(parents=True, exist_ok=True)
    (args.outdir / args.sub_dir).mkdir(parents=True, exist_ok=True)

    if args.tmp_dir:
        print(f"Using custom tmp dir: {args.tmp_dir}")
        Path(args.tmp_dir).mkdir(parents=True, exist_ok=True)

    if args.format == "TIFF":
        suffix = "tif"
    elif args.format == "PNG":
        suffix = "png"
    else:
        raise NotImplementedError

    SHUFFLE = True  # shuffle subtile order within shards (with fixed seed)

    # subtile_stats = split_tiles(train_files)
    images = sorted(args.image_dir.glob("*.tif"))
    masks = sorted(args.mask_dir.glob("*.tif"))
    lus = sorted(args.lu_dir.glob("*.tif"))

    image_names = {i.name for i in images}
    mask_names = {i.name for i in masks}
    lu_names = {i.name for i in lus}

    # limit the training set to tiles that have matching image, mask, and lu files
    common_names = image_names & mask_names & lu_names
    train_images = sorted(i for i in images if i.name in common_names)
    train_masks = sorted(i for i in masks if i.name in common_names)
    train_lus = sorted(i for i in lus if i.name in common_names)

    # print(len(train_images))
    # print(len(train_masks))
    # exit()
    # print(len(train_lus))

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
    )

    subtile_stats = split_tiles(
        train_images,
        train_masks,
        train_lus,
        args.workers,
        str(args.outdir / args.sub_dir / "train-%06d.tar"),
        **cfg,
    )
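    # subtile_stats holds one (tile, frac, status) record per processed subtile;
    # it is written to the stats csv below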

    with open(args.outdir / args.stats_file, "w") as fout:
        fout.write("tile,frac,status\n")
        for fname, frac, status in subtile_stats:
            fout.write(f"{fname},{frac},{status}\n")

    # rebalance shards so we get similar distributions in all shards
    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmpdir:
        print(f"Created a temporary directory: {tmpdir}")

        print("Extract source tars")
        # untar input
        for tf_name in sorted((args.outdir / args.sub_dir).glob("train-00*.tar")):
            with tarfile.open(tf_name) as tf:
                tf.extractall(tmpdir)

        print("Write balanced shards from deadtree samples")
        df = pd.read_csv(args.outdir / args.stats_file)

        df = df[df.status > 0]

        n_valid = len(df)
        splits = split_df(df, SHARDSIZE)

        # preserve last shard if more than 50% of values are present
        if SHARDSIZE // 2 < len(splits[-1]) < SHARDSIZE:
            # fill last shard with duplicates (not ideal...)
            n_missing = SHARDSIZE - len(splits[-1])
            # df_extra = splits[-1].sample(n=n_missing, random_state=42)
            splits[-1].extend(np.random.choice(splits[-1], size=n_missing).tolist())

        # drop incomplete shards
        splits = [x for x in splits if len(x) == SHARDSIZE]
        assert len(splits) > 0, "Something went wrong"

        for s_cnt, s in enumerate(splits):

            with tarfile.open(
                args.outdir / args.sub_dir / f"train-balanced-{s_cnt:06}.tar", "w"
            ) as dst:

                if SHUFFLE:
                    random.shuffle(s)
                for i in s:
                    dst.add(f"{tmpdir}/{i}.mask.{suffix}", f"{i}.mask.{suffix}")
                    dst.add(f"{tmpdir}/{i}.lu.{suffix}", f"{i}.lu.{suffix}")
                    dst.add(f"{tmpdir}/{i}.rgbn.{suffix}", f"{i}.rgbn.{suffix}")
                    dst.add(f"{tmpdir}/{i}.txt", f"{i}.txt")

    # create sets for random tile dataset
    # use all subtiles not covered in train

    n_subtiles = (args.source_dim // args.tile_size) ** 2

    all_subtiles = []
    for image_name in image_names:
        all_subtiles.extend(
            [f"{Path(image_name).stem}_{c:03}" for c in range(n_subtiles)]
        )
    all_subtiles = set(all_subtiles)

    n_samples = n_valid * OVERSAMPLE_FACTOR
    # exclude subtiles that are already used as dead-tree samples (status == 1)
    deadtree_subtiles = {x[0] for x in subtile_stats if int(x[2]) == 1}
    random_subtiles = random.sample(tuple(all_subtiles - deadtree_subtiles), n_samples)

    # the source tiles we need to process (strip the _NNN subtile suffix)
    random_tiles = sorted({x[:-4] for x in random_subtiles})

    all_images = sorted(args.image_dir.glob("*.tif"))
    random_images = [x for x in all_images if x.stem in random_tiles]

    print("STATS")
    print(len(all_subtiles))
    print(len(subtile_stats))
    print(len(random_subtiles))
    print(len(random_images))

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
        valid_subtiles=random_subtiles,  # subset data with random selection of subtiles
    )

    random_images_names = {i.name for i in random_images}
    random_lus = [i for i in lus if i.name in random_images_names]

    subtile_stats_rnd = split_tiles(
        random_images,
        [None] * len(random_images),
        random_lus,
        args.workers,
        str(args.outdir / args.sub_dir / "train-randomsamples-%06d.tar"),
        **cfg,
    )

    stats_file_rnd = Path(args.stats_file.stem + "_rnd.csv")
    with open(args.outdir / stats_file_rnd, "w") as fout:
        fout.write("tile,frac,status\n")
        for fname, frac, status in subtile_stats_rnd:
            fout.write(f"{fname},{frac},{status}\n")

    # also create combo dataset
    # source A: train-balanced, source B: randomsample
    # NOTE: the combo dataset has double the default shard size (2*128);
    #       samples alternate between the balanced and the random-sample shards
    train_balanced_shards = [
        str(x) for x in sorted((args.outdir / args.sub_dir).glob("train-balanced*"))
    ]
    train_balanced_shards_rnd = [
        str(x) for x in sorted((args.outdir / args.sub_dir).glob("train-random*"))
    ]
    train_balanced_shards_rnd = train_balanced_shards_rnd[: len(train_balanced_shards)]

    shardpattern = str(args.outdir / args.sub_dir / "train-combo-%06d.tar")

    with wds.ShardWriter(shardpattern, maxcount=SHARDSIZE * 2) as sink:
        for shardA, shardB in zip(train_balanced_shards, train_balanced_shards_rnd):

            for sA, sB in zip(wds.WebDataset(shardA), wds.WebDataset(shardB)):
                sink.write(sA)
                sink.write(sB)

    # clean up: keep only the combo shards; remove the raw, balanced, and random-sample shards
    for filename in (args.outdir / args.sub_dir).glob("train-random*"):
        filename.unlink()
    for filename in (args.outdir / args.sub_dir).glob("train-balanced*"):
        filename.unlink()
    for filename in (args.outdir / args.sub_dir).glob("train-0*"):
        filename.unlink()
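For reference, the combo shards written above can be consumed with webdataset the same way they are produced. A minimal sketch, assuming the default --subdir of "train" and a hypothetical shard range (adjust to the number of shards actually written):

import webdataset as wds

# each sample is a dict keyed by file extension (mask/lu/rgbn images plus the
# txt metadata added above), with the subtile name available as "__key__"
dataset = wds.WebDataset("train/train-combo-{000000..000009}.tar")
for sample in dataset:
    print(sample["__key__"], sorted(sample.keys()))
    break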