def test_partitioned_totals_approx_equal(self):
    """Check that every partition of the fake data sums to about 45 in ``frac``.

    The fake data is split with ``split_df(..., 10)``; for each partition the
    ``frac`` values of the rows whose ``tile`` is in that partition are summed
    and compared against 45 with an absolute tolerance of 5.
    """
    # NOTE: dodgy, hand-crafted, and should be replaced by something rigid
    partitions = split_df(self.data_fake, 10)

    totals = []
    for part in partitions:
        in_partition = self.data_fake.tile.isin(part)
        totals.append(self.data_fake[in_partition].frac.sum())

    expected = [45] * len(totals)
    assert expected == pytest.approx(totals, abs=5)
def test_number_of_partitions_as_requested(self):
    """Splitting with size 3 must yield ``ceil(len(data) / 3)`` partitions."""
    partitions = split_df(self.data, 3)
    expected_count = math.ceil(len(self.data) / 3)
    assert len(partitions) == expected_count
def test_total_size_unchanged(self):
    """Concatenating all partitions must give back as many items as the input."""
    partitions = split_df(self.data, 3)
    # Join the partitions pairwise with ``+`` into one flat sequence.
    merged = reduce(lambda joined, part: joined + part, partitions)
    assert len(merged) == len(self.data)
def test_catch_tiles_without_deadtrees(self):
    """``split_df`` must reject ``self.data_bad`` with a ``ValueError``.

    NOTE(review): judging by the test name, ``data_bad`` presumably contains
    tiles without deadtrees -- confirm against the fixture.
    """
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        split_df(self.data_bad, 3)
def test_catch_invalid_size(self, n):
    """``split_df`` must raise ``ValueError`` for the invalid size ``n``.

    NOTE(review): ``n`` is presumably injected by a parametrize decorator or
    fixture that is not visible here -- confirm.
    """
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        split_df(self.data, n)
def _write_stats_csv(path, stats):
    """Write ``(tile, frac, status)`` records to *path* as CSV with a header row."""
    # Written as UTF-8 explicitly: main() reads the stats file back with
    # pd.read_csv, whose default encoding is UTF-8.
    with open(path, "w", encoding="utf-8") as fout:
        fout.write("tile,frac,status\n")
        fout.writelines(
            f"{fname},{frac},{status}\n" for fname, frac, status in stats
        )


def main():
    """Command-line entry point: create the combined ("combo") training shards.

    Steps, in order:

    1. Parse the arguments and create the output / shard / tmp directories.
    2. Keep only tiles present (by file name) in the image, mask and LU dirs
       and pass them to ``split_tiles``, writing ``train-%06d.tar`` shards and
       a stats CSV.
    3. Extract those shards into a temporary directory and repack the samples
       with ``status > 0`` into ``train-balanced-*.tar`` shards of exactly
       ``SHARDSIZE`` samples, using ``split_df`` for the partitioning.
    4. Randomly pick ``n_valid * OVERSAMPLE_FACTOR`` subtiles that are not
       flagged with status 1 in the first pass and write them as
       ``train-randomsamples-*.tar`` shards plus a ``*_rnd.csv`` stats file.
    5. Interleave balanced and random samples into ``train-combo-*.tar`` and
       delete the intermediate shards.

    Relies on ``split_tiles``, ``split_df``, ``SHARDSIZE``,
    ``OVERSAMPLE_FACTOR``, ``psutil`` and ``wds`` being defined at module level.

    Raises:
        NotImplementedError: for an output format other than TIFF/PNG.
        AssertionError: if no complete balanced shard could be built.
        ValueError: from ``random.sample`` if fewer candidate subtiles exist
            than samples requested.
    """
    parser = argparse.ArgumentParser()
    for positional in ("image_dir", "mask_dir", "lu_dir", "outdir"):
        parser.add_argument(positional, type=Path)

    num_cores = psutil.cpu_count(logical=False)
    parser.add_argument(
        "--workers",
        dest="workers",
        type=int,
        default=num_cores,
        help="number of workers for parallel execution [def: %(default)s]",
    )
    parser.add_argument(
        "--source_dim",
        dest="source_dim",
        type=int,
        default=2048,
        help="size of input tiles [def: %(default)s]",
    )
    parser.add_argument(
        "--tile_size",
        dest="tile_size",
        type=int,
        default=256,
        help="size of final tiles that are then passed to the model [def: %(default)s]",
    )
    parser.add_argument(
        "--format",
        dest="format",
        type=str,
        default="TIFF",
        choices=["PNG", "TIFF"],
        help="target file format (PNG, TIFF) [def: %(default)s]",
    )
    parser.add_argument(
        "--tmp-dir",
        dest="tmp_dir",
        type=Path,
        default=None,
        help="use this location as tmp dir",
    )
    parser.add_argument(
        "--subdir",
        dest="sub_dir",
        default="train",
        help="use this location as sub_dir",
    )
    parser.add_argument(
        "--stats",
        dest="stats_file",
        type=Path,
        default=Path("stats.csv"),
        help="use this file to record stats",
    )
    args = parser.parse_args()

    shard_dir = args.outdir / args.sub_dir
    args.outdir.mkdir(parents=True, exist_ok=True)
    shard_dir.mkdir(parents=True, exist_ok=True)

    if args.tmp_dir:
        print(f"Using custom tmp dir: {args.tmp_dir}")
        args.tmp_dir.mkdir(parents=True, exist_ok=True)

    # File suffix of the subtiles stored inside the shards.
    suffix = {"TIFF": "tif", "PNG": "png"}.get(args.format)
    if suffix is None:
        raise NotImplementedError

    SHUFFLE = True  # shuffle subtile order within shards (with fixed seed)

    images = sorted(args.image_dir.glob("*.tif"))
    masks = sorted(args.mask_dir.glob("*.tif"))
    lus = sorted(args.lu_dir.glob("*.tif"))

    image_names = {p.name for p in images}
    mask_names = {p.name for p in masks}
    lu_names = {p.name for p in lus}

    # Limit the training set to tiles that exist in all three directories.
    # The intersection is computed once (not once per element); filtering an
    # already sorted list keeps it sorted, so no re-sort is needed.
    common_names = image_names & mask_names & lu_names
    train_images = [p for p in images if p.name in common_names]
    train_masks = [p for p in masks if p.name in common_names]
    train_lus = [p for p in lus if p.name in common_names]

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
    )

    subtile_stats = split_tiles(
        train_images,
        train_masks,
        train_lus,
        args.workers,
        str(shard_dir / "train-%06d.tar"),
        **cfg,
    )

    stats_path = args.outdir / args.stats_file
    _write_stats_csv(stats_path, subtile_stats)

    # rebalance shards so we get similar distributions in all shards
    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmpdir:
        print(f"Created a temporary directory: {tmpdir}")

        print("Extract source tars")
        for tar_path in sorted(shard_dir.glob("train-00*.tar")):
            with tarfile.open(tar_path) as src:
                src.extractall(tmpdir)

        print("Write balanced shards from deadtree samples")
        df = pd.read_csv(stats_path)
        df = df[df.status > 0]

        n_valid = len(df)
        splits = split_df(df, SHARDSIZE)

        # Preserve the last shard if more than 50% of its values are present
        # by filling it with duplicates (not ideal...). The emptiness guard
        # lets the sanity check below report the problem instead of an
        # IndexError on ``splits[-1]``.
        if splits and SHARDSIZE // 2 < len(splits[-1]) < SHARDSIZE:
            n_missing = SHARDSIZE - len(splits[-1])
            splits[-1].extend(np.random.choice(splits[-1], size=n_missing).tolist())

        # drop incomplete shards
        splits = [shard for shard in splits if len(shard) == SHARDSIZE]

        assert len(splits) > 0, "Something went wrong"

        for shard_idx, shard in enumerate(splits):
            shard_path = shard_dir / f"train-balanced-{shard_idx:06}.tar"
            with tarfile.open(shard_path, "w") as dst:
                if SHUFFLE:
                    random.shuffle(shard)
                for sample in shard:
                    dst.add(f"{tmpdir}/{sample}.mask.{suffix}", f"{sample}.mask.{suffix}")
                    dst.add(f"{tmpdir}/{sample}.lu.{suffix}", f"{sample}.lu.{suffix}")
                    dst.add(f"{tmpdir}/{sample}.rgbn.{suffix}", f"{sample}.rgbn.{suffix}")
                    dst.add(f"{tmpdir}/{sample}.txt", f"{sample}.txt")

    # create sets for random tile dataset
    # use all subtiles not covered in train
    n_subtiles = (args.source_dim // args.tile_size) ** 2
    all_subtiles = {
        f"{Path(image_name).stem}_{c:03}"
        for image_name in image_names
        for c in range(n_subtiles)
    }
    train_subtiles = {x[0] for x in subtile_stats if int(x[2]) == 1}

    n_samples = n_valid * OVERSAMPLE_FACTOR
    # Sample from a sorted list: set iteration order depends on string hash
    # randomisation, which would make the draw irreproducible even when the
    # random module is seeded.
    random_subtiles = random.sample(sorted(all_subtiles - train_subtiles), n_samples)

    # The tiles that need processing: strip the trailing "_<index>" from each
    # subtile name (rsplit also works when the index has more than 3 digits).
    random_tile_stems = {name.rsplit("_", 1)[0] for name in random_subtiles}
    random_images = [p for p in images if p.stem in random_tile_stems]

    print("STATS")
    print(len(all_subtiles))
    print(len(subtile_stats))
    print(len(random_subtiles))
    print(len(random_images))

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
        valid_subtiles=random_subtiles,  # subset data with random selection of subtiles
    )

    # NOTE(review): random_images is drawn from *all* images, but random_lus
    # only contains LU tiles that exist; if an image has no LU tile the two
    # lists differ in length -- confirm split_tiles tolerates that.
    random_images_names = {p.name for p in random_images}
    random_lus = [p for p in lus if p.name in random_images_names]

    subtile_stats_rnd = split_tiles(
        random_images,
        [None] * len(random_images),
        random_lus,
        args.workers,
        str(shard_dir / "train-randomsamples-%06d.tar"),
        **cfg,
    )

    stats_file_rnd = Path(args.stats_file.stem + "_rnd.csv")
    _write_stats_csv(args.outdir / stats_file_rnd, subtile_stats_rnd)

    # also create combo dataset
    # source A: train-balanced, source B: randomsample
    # NOTE: combo shards hold twice SHARDSIZE samples; samples alternate
    # between regular and random sample
    balanced_shards = [str(p) for p in sorted(shard_dir.glob("train-balanced*"))]
    random_shards = [str(p) for p in sorted(shard_dir.glob("train-random*"))]
    random_shards = random_shards[: len(balanced_shards)]

    shardpattern = str(shard_dir / "train-combo-%06d.tar")

    with wds.ShardWriter(shardpattern, maxcount=SHARDSIZE * 2) as sink:
        for shard_a, shard_b in zip(balanced_shards, random_shards):
            for sample_a, sample_b in zip(wds.WebDataset(shard_a), wds.WebDataset(shard_b)):
                sink.write(sample_a)
                sink.write(sample_b)

    # Remove the intermediate shards (random, balanced and the raw train-0*
    # shards); the train-combo-* shards are left in place.
    for pattern in ("train-random*", "train-balanced*", "train-0*"):
        for filename in shard_dir.glob(pattern):
            filename.unlink()