def from_processed(url: str, train=False):
    """Load processed ``Example`` records from WebDataset shards matching a glob.

    Args:
        url: Glob pattern for the shard files; the matches are sorted before use.
        train: When true, return a lazily iterated pipeline with a large
            shuffle buffer; otherwise eagerly materialise every example.

    Returns:
        A WebDataset pipeline (``train=True``) or a list of ``Example`` objects.
    """

    def to_example(sample):
        # The decoded "json" entry holds the keyword arguments of one Example.
        return Example(**sample["json"])

    shard_paths = sorted(glob.glob(url))
    if not train:
        return list(wds.WebDataset(shard_paths).decode().map(to_example))
    return (
        wds.WebDataset(shard_paths)
        .shuffle(size=10000000, initial=100000)
        .decode()
        .map(to_example)
    )
def build_dataset(
    shards: List[str],
    bs: int,
    transform_func: Callable,
    shuffle: Optional[int] = 128,
    shard_size: Optional[int] = 128,
) -> wds.WebDataset:
    """Assemble the (image, mask, distmap, lu, stats) training pipeline.

    Args:
        shards: Tar shard URLs/paths.
        bs: Batch size; only used to derive the nominal dataset length.
        transform_func: Augmentation forwarded to the module-level ``transform``.
        shuffle: Sample shuffle buffer size.
        shard_size: Assumed samples per shard, used for the nominal length.

    Returns:
        A WebDataset pipeline yielding ``(image, mask, distmap, lu, stats)`` tuples.
    """
    # Nominal length in batches: shards * samples-per-shard // batch size.
    nominal_length = len(shards) * shard_size // bs
    pipeline = wds.WebDataset(shards, length=nominal_length)
    pipeline = pipeline.shuffle(shuffle).map(sample_decoder)
    pipeline = pipeline.rename(image="rgbn.tif", mask="mask.tif", lu="lu.tif", stats="txt")
    # NOTE(review): `transform`, `in_channels` and `classes` are module-level
    # names not visible in this block.
    augment = partial(
        transform,
        transform_func=transform_func,
        in_channels=in_channels,
        classes=classes,
        distmap=True,
    )
    return pipeline.map(augment).to_tuple("image", "mask", "distmap", "lu", "stats")
def display_shard_images(client, bucket, tar_name, objects=2, etl_id=""):
    """Plot the first ``objects`` images of a remote tar shard in a two-column grid.

    Args:
        client: Object whose ``object_url(bucket, name, transform_id=...)``
            returns the shard URL.
        bucket: Bucket containing the shard.
        tar_name: Name of the tar shard inside ``bucket``.
        objects: Number of images to display (default 2).
        etl_id: Optional transform id forwarded to ``client.object_url``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    dataset = (
        wds.WebDataset(
            client.object_url(bucket, tar_name, transform_id=etl_id),
            handler=wds.handlers.warn_and_continue,
        )
        .decode("rgb")
        .to_tuple("jpg;png;jpeg;npy cls", handler=wds.handlers.warn_and_continue)
        .map_tuple(to_tensor, lambda x: x)
    )
    loader = wds.WebLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
    samples = iter(loader)

    # Round up so an odd count still gets a row for its last image, and keep
    # at least one row so subplots() always receives a valid shape.
    n_rows = max(1, (objects + 1) // 2)
    # squeeze=False keeps axarr 2-D even for a single row; without it the
    # default objects=2 yields a 1-D array and axarr[row, column] fails.
    _, axarr = plt.subplots(n_rows, 2, figsize=(12, 12), squeeze=False)
    for i in range(objects):
        row, column = divmod(i, 2)
        img_tensor, _ = next(samples)
        # CHW tensor -> HWC array clipped to [0, 1] for imshow.
        img = np.transpose(np.asarray(img_tensor.squeeze()), (1, 2, 0))
        img = np.clip(img, 0, 1)
        ax = axarr[row, column]
        ax.set_yticks([])
        ax.set_xticks([])
        ax.imshow(img, interpolation="nearest")
    plt.show()
def make_train_loader_wds(args):
    """Create the WebDataset-backed training loader.

    Reads ``args.trainshards``, shuffles with a buffer of ``args.shuffle``
    samples, decodes images with PIL, applies the training transform and
    batches ``args.batch_size`` samples inside the pipeline. Under
    distributed training the loader is repeated and sliced so every node
    yields the same number of batches.

    Args:
        args: Namespace providing ``trainsize``, ``batch_size``, ``trainshards``,
            ``shuffle``, ``distributed``, ``workers`` and (when distributed)
            ``world_size``.

    Returns:
        A ``wds.WebLoader`` yielding pre-built batches.
    """
    print("=> using WebDataset loader")
    train_transform = make_train_transform(args)
    num_batches = args.trainsize // args.batch_size

    pipeline = wds.WebDataset(args.trainshards, length=num_batches)
    pipeline = pipeline.shuffle(args.shuffle)
    pipeline = pipeline.decode("pil")
    pipeline = pipeline.to_tuple("jpg;png;jpeg cls")
    pipeline = pipeline.map_tuple(train_transform, identity)

    if not args.distributed:
        pipeline = pipeline.batched(args.batch_size)
    else:
        # DistributedDataParallel works best without a trailing partial batch.
        pipeline = pipeline.batched(args.batch_size, partial=False)

    # WebLoader is a DataLoader with WebDataset's fluent helpers; batching has
    # already been done in the pipeline, hence batch_size=None.
    loader = wds.WebLoader(
        pipeline,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
    )

    if not args.distributed:
        return loader

    # With DDP every node must yield the same number of batches, so a little
    # data is reused (repeat twice, then cut to a fixed count). Iterating the
    # whole dataset on all nodes is the cleaner alternative.
    # NOTE(review): hard-coded size (equals the ImageNet-1k training set);
    # presumably should follow args.trainsize — confirm.
    dataset_size = 1281167
    number_of_batches = dataset_size // (args.batch_size * args.world_size)
    print("# batches per node = ", number_of_batches)
    loader = loader.repeat(2).slice(number_of_batches)
    # Only affects what len() reports; some frameworks query it.
    loader.length = number_of_batches
    return loader
def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128):
    """Open a tarred audio WebDataset alongside its text manifest.

    Args:
        manifest_path: Manifest loaded with ``collections.ASRAudioText``,
            indexed by file id.
        tar_filepaths: Shard list or a single pattern string. In a string,
            ``(``, ``[``, ``<``, ``_OP_`` become ``{`` and ``)``, ``]``,
            ``>``, ``_CL_`` become ``}``.
        shuffle_n: In-shard shuffle buffer size; ``<= 0`` disables shuffling.

    Raises:
        LightningNotInstalledException: If ``HAVE_OMEGACONG_WEBDATASET`` is false.
    """
    self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)

    if isinstance(tar_filepaths, str):
        # Alternative bracket spellings are normalised to brace-expansion braces.
        for opening in ('(', '[', '<', '_OP_'):
            tar_filepaths = tar_filepaths.replace(opening, "{")
        for closing in (')', ']', '>', '_CL_'):
            tar_filepaths = tar_filepaths.replace(closing, "}")

    if not HAVE_OMEGACONG_WEBDATASET:
        raise LightningNotInstalledException(self)

    dataset = wd.WebDataset(urls=tar_filepaths, nodesplitter=None)
    if shuffle_n <= 0:
        logging.info("WebDataset will not shuffle files within the tar files.")
    else:
        dataset = dataset.shuffle(shuffle_n)

    # Each item becomes an (audio bytes, sample key) pair.
    self.audio_dataset = dataset.rename(audio='wav', key='__key__').to_tuple('audio', 'key')
    self.audio_iter = iter(self.audio_dataset)
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: Union[str, List[str]],
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
    is_regression_task: bool = False,
):
    """Tarred speech-label dataset streamed through WebDataset.

    Args:
        audio_tar_filepaths: Shard list or pattern, resolved/partitioned by
            ``expand_audio_filepaths``.
        manifest_filepath: Manifest(s) loaded with ``collections.ASRSpeechLabel``.
        labels: Label vocabulary; falls back to the manifest's unique labels
            when empty/None.
        featurizer: Stored for use by the sample builder.
        shuffle_n: In-shard shuffle buffer; ``<= 0`` disables shuffling.
        min_duration / max_duration: Duration filters passed to the manifest.
        trim: Stored flag.
        shard_strategy / global_rank / world_size: Forwarded to
            ``expand_audio_filepaths``.
        is_regression_task: Accepted but not used in this method.
    """
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath,
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.file_occurence = count_occurence(self.collection.mapping)

    self.featurizer = featurizer
    self.trim = trim
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    # Bidirectional label <-> index maps (later duplicates win, as before).
    self.label2id = {label: label_id for label_id, label in enumerate(self.labels)}
    self.id2label = dict(enumerate(self.labels))

    for idx in range(min(5, len(self.labels))):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

    audio_tar_filepaths = expand_audio_filepaths(
        audio_tar_filepaths=audio_tar_filepaths,
        shard_strategy=shard_strategy,
        world_size=world_size,
        global_rank=global_rank,
    )

    # Put together WebDataset
    dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
    if shuffle_n <= 0:
        logging.info("WebDataset will not shuffle files within the tar files.")
    else:
        dataset = dataset.shuffle(shuffle_n)

    self._dataset = (
        dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )
def setup(self, stage=None):
    """Build the train, validation and test WebDataset pipelines.

    Shards are read from ``<train_glob>/<split>/<split>-{0..N}.tar``, where
    ``train_glob`` comes from ``self.args`` (default ``/pvc/output/processing``)
    and ``N`` is ``self.get_num_files`` of that split directory. Every
    pipeline shuffles with a 100-sample buffer, decodes images with PIL,
    applies the split's transform and yields batched ``(image, info)``
    tuples (40 per batch for train, 20 for val/test). No data is downloaded.

    Args:
        stage: Lightning stage name; unused, all three splits are always built.
    """
    data_path = self.args.get("train_glob", "/pvc/output/processing")
    # All splits accept the same image encodings. Previously val/test only
    # accepted "ppm", inconsistent with the training split.
    image_keys = "ppm;jpg;jpeg;png"

    def make_split(split, transform, batch_size):
        # Builds one split's pipeline from its shard directory.
        base_url = data_path + "/" + split
        count = self.get_num_files(base_url)
        # NOTE(review): "{0..count}" is an inclusive brace range; if
        # get_num_files returns the number of shards, this names one shard
        # past the end — confirm.
        url = "{}/{}-{}".format(base_url, split, "{0.." + str(count) + "}.tar")
        return (
            wds.WebDataset(url, handler=wds.warn_and_continue)
            .shuffle(100)
            .decode("pil")
            .rename(image=image_keys, info="cls")
            .map_dict(image=transform)
            .to_tuple("image", "info")
            .batched(batch_size)
        )

    self.train_dataset = make_split("train", self.train_transform, 40)
    self.valid_dataset = make_split("val", self.valid_transform, 20)
    self.test_dataset = make_split("test", self.valid_transform, 20)
def objective(trial: optuna.trial.Trial) -> float:
    """Train one LitPointNet2 configuration sampled by ``trial`` and score it.

    Gamma and proton shards are combined with ``SampleEqually`` for both
    training and validation. The sampled hyper-parameters are logged, the
    learning rate is tuned (``auto_lr_find``), and the model is fitted for up
    to 20 epochs with Optuna pruning on ``val/loss``.

    Args:
        trial: Optuna trial supplying the hyper-parameter suggestions.

    Returns:
        The final ``val/loss`` reported by the trainer.
    """
    gamma_train = wds.WebDataset("/run/media/jacob/data/FACT_Dataset/fact-gamma-10-{0000..0062}.tar").shuffle(20000).decode()
    proton_train = wds.WebDataset("/run/media/jacob/data/FACT_Dataset/fact-proton-10-{0000..0010}.tar").shuffle(20000).decode()
    gamma_test = wds.WebDataset("/run/media/jacob/data/FACT_Dataset/fact-gamma-10-{0063..0072}.tar").decode()
    proton_test = wds.WebDataset("/run/media/jacob/data/FACT_Dataset/fact-proton-10-{0011..0013}.tar").decode()
    # NOTE(review): SampleEqually is defined elsewhere; presumably draws from
    # both sources equally to balance gamma vs proton events.
    train_dataset = SampleEqually([gamma_train, proton_train])
    test_dataset = SampleEqually([gamma_test, proton_test])
    train_loader = DataLoader(train_dataset, num_workers=16, batch_size=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, num_workers=4, batch_size=1, pin_memory=True)

    # Search space: point sampling ratios/radii, maximum neighbour count,
    # fully-connected layer widths and the dropout rate.
    # suggest_float replaces the deprecated suggest_uniform (same distribution).
    config = {
        "sample_ratio_one": trial.suggest_float("sample_ratio_one", 0.1, 0.9),
        "sample_radius_one": trial.suggest_float("sample_radius_one", 0.1, 0.9),
        "sample_max_neighbor": trial.suggest_int("sample_max_neighbor", 8, 72),
        "sample_ratio_two": trial.suggest_float("sample_ratio_two", 0.1, 0.9),
        "sample_radius_two": trial.suggest_float("sample_radius_two", 0.1, 0.9),
        "fc_1": trial.suggest_int("fc_1", 128, 256),
        "fc_1_out": trial.suggest_int("fc_1_out", 32, 128),
        "fc_2_out": trial.suggest_int("fc_2_out", 16, 96),
        "dropout": trial.suggest_float("dropout", 0.1, 0.9),
    }
    num_classes = 2
    import pytorch_lightning as pl

    model = LitPointNet2(num_classes, lr=0.0001, config=config)
    trainer = pl.Trainer(
        logger=True,
        limit_val_batches=10000,
        limit_train_batches=10000,
        checkpoint_callback=False,
        auto_lr_find=True,
        max_epochs=20,
        gpus=1,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val/loss")],
    )
    trainer.logger.log_hyperparams(config)
    trainer.tune(model=model, train_dataloader=train_loader, val_dataloaders=test_loader)
    trainer.fit(model=model, train_dataloader=train_loader, val_dataloaders=test_loader)
    return trainer.callback_metrics["val/loss"].item()
def loader(urls, batch_size, workers, samples_per_shard=500):
    """Build a WebLoader yielding batches of ``(tensor, class)`` pairs.

    Args:
        urls: Shard URLs. The nominal length uses ``len(urls)``, so pass a
            list: for a brace-pattern string, ``len`` is its character count.
        batch_size: Samples per batch.
        workers: Number of DataLoader worker processes.
        samples_per_shard: Assumed samples in each shard, used only for the
            nominal length (default 500, the previously hard-coded value).

    Returns:
        A ``wds.WebLoader`` whose ``len()`` reports the nominal batch count.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    etl_dataset = (
        wds.WebDataset(urls, handler=wds.handlers.warn_and_continue)
        .decode("rgb")
        .to_tuple("npy cls", handler=wds.handlers.warn_and_continue)
        .map_tuple(to_tensor, lambda x: x)
    )
    # Nominal number of batches; only reported via len(), never enforced.
    ds_size = (samples_per_shard * len(urls)) // batch_size
    # NOTE(review): the sample-level dataset is also given the batch count as
    # its length — kept as before; confirm this is intended.
    etl_dataset = etl_dataset.with_length(ds_size)
    web_loader = wds.WebLoader(
        etl_dataset,
        batch_size=batch_size,
        num_workers=workers,
    )
    return web_loader.with_length(ds_size)
def __init__(self, urls, length, shuffle_buffer, nodesplitter=None, memory_cache=None):
    """Wrap a WebDataset over ``urls`` with optional shard and sample shuffling.

    Args:
        urls: Shard URLs passed to ``wds.WebDataset``.
        length: Nominal dataset length.
        shuffle_buffer: Sample shuffle buffer; values > 1 also enable shard shuffling.
        nodesplitter: Optional node splitter forwarded to WebDataset.
        memory_cache: Stored on the instance; not used here.
    """
    super().__init__()
    self.memory_cache = memory_cache
    # Shard-order and in-shard shuffling are switched on together.
    shuffling = bool(shuffle_buffer > 1)
    self.dataset = wds.WebDataset(
        urls,
        shardshuffle=shuffling,
        length=length,
        nodesplitter=nodesplitter,
    )
    if shuffling:
        self.dataset = self.dataset.shuffle(shuffle_buffer)
# Script setup: render plots off-screen and, when CUDA is available, build a
# streaming dataloader over every local .tar shard plus a GPU trainer.
matplotlib.use('Agg')
all_tars = []
model = GAN()
if torch.cuda.is_available():
    decods = my_decoders(128)
    # Collect every .tar under the current directory, skipping any whose
    # directory path or file name contains "out".
    for root, dirs, files in os.walk("."):
        for file in files:
            if file.endswith(
                    ".tar") and "out" not in root and "out" not in file:
                all_tars.append(os.path.join(root, file))
    # length=inf: nominal length for an unbounded stream; each item is a
    # ("gt.jpg", "__key__") pair, batched 16 at a time in the pipeline.
    dataset = wds.WebDataset(all_tars, length=float("inf")) \
        .decode(decods.simple_decoder).to_tuple("gt.jpg", "__key__", handler=dummy_func).batched(16)
    # batch_size=None because batching already happened in the pipeline.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=20, collate_fn=coll)
    trainer = pl.Trainer(
        gpus=1, log_every_n_steps=10, max_epochs=10, profiler=False, precision=16, distributed_backend='ddp')  #, logger=neptune_logger)
else:
    # Without CUDA only the decoders are created; no dataset or trainer.
    decods = my_decoders(128)
def __init__(
    self,
    text_tar_filepaths: str,
    num_batches: int,
    shuffle_n: int = 0,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 1,
):
    """Tarred text-normalization decoder dataset streamed through WebDataset.

    Args:
        text_tar_filepaths: Shard list or brace-pattern string. In a string,
            ``(``, ``[``, ``<``, ``_OP_`` become ``{`` and ``)``, ``]``, ``>``,
            ``_CL_`` become ``}`` before brace expansion.
        num_batches: Total batch count; ``self.length`` is this divided by
            ``world_size``.
        shuffle_n: In-shard shuffle buffer; ``<= 0`` disables shuffling.
        shard_strategy: ``"scatter"`` gives each rank a contiguous slice of
            shards; ``"replicate"`` gives every rank all shards.
        global_rank: Rank of this process.
        world_size: Number of distributed workers.

    Raises:
        ValueError: If ``shard_strategy`` is not supported.
    """
    super(TarredTextNormalizationDecoderDataset, self).__init__()

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(
            f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(text_tar_filepaths, str):
        replacements = (
            ('(', "{"), ('[', "{"), ('<', "{"), ('_OP_', "{"),
            (')', "}"), (']', "}"), ('>', "}"), ('_CL_', "}"),
        )
        for token, brace in replacements:
            text_tar_filepaths = text_tar_filepaths.replace(token, brace)
        # Expand the brace pattern into an explicit shard list.
        text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

    if shard_strategy == 'replicate':
        logging.info(
            "All tarred dataset shards will be replicated across all nodes."
        )
    elif shard_strategy == 'scatter':
        logging.info(
            "All tarred dataset shards will be scattered evenly across all nodes."
        )
        n_shards = len(text_tar_filepaths)
        if n_shards % world_size != 0:
            logging.warning(
                f"Number of shards in tarred dataset ({n_shards}) is not divisible "
                f"by number of distributed workers ({world_size}).")
        # Each rank takes a contiguous slice; leftover shards are dropped.
        per_rank = n_shards // world_size
        begin_idx = per_rank * global_rank
        end_idx = begin_idx + per_rank
        logging.info('Begin Index : %d' % (begin_idx))
        logging.info('End Index : %d' % (end_idx))
        text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
        logging.info(
            "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
            global_rank, begin_idx, end_idx)
    else:
        raise ValueError(
            f"Invalid shard strategy! Allowed values are: {valid_shard_strategies}"
        )

    # Put together WebDataset
    self._dataset = wd.WebDataset(urls=text_tar_filepaths, nodesplitter=None)
    self.length = num_batches // world_size
    if shuffle_n <= 0:
        logging.info(
            "WebDataset will not shuffle files within the tar files.")
    else:
        self._dataset = self._dataset.shuffle(shuffle_n)

    self._dataset = (
        self._dataset.rename(pkl='pkl', key='__key__')
        .to_tuple('pkl', 'key')
        .map(f=self._build_sample)
    )
if ENABLE_WEBDATASET: DATASET_SIZE = int( 1e9 ) # You need to set a nominal length for the Dataset in order to avoid warnings from DataLoader myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS image_text_mapping = {myimg: imagetransform, mycap: tokenize} image_mapping = {myimg: imagepreproc} num_batches = DATASET_SIZE // BATCH_SIZE ds = ( wds.WebDataset(DATASET, length=num_batches) # .shuffle(is_shuffle) # Commented out for WebDataset as the behaviour cannot be predicted yet .map_dict(**image_text_mapping).map_dict(**image_mapping). to_tuple(mycap, myimg).batched( BATCH_SIZE, partial=False ) # It is good to avoid partial batches when using Distributed training ) else: ds = TextImageDataset( args.image_text_folder, text_len=TEXT_SEQ_LEN, image_size=IMAGE_SIZE, resize_ratio=args.resize_ratio, truncate_captions=args.truncate_captions, tokenizer=tokenizer, shuffle=is_shuffle, ) assert len(ds) > 0, 'dataset is empty'
image_text_mapping = { myimg: imagetransform, mycap: tokenize } image_mapping = { myimg: imagepreproc } def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available. if mycap not in item: return False if myimg not in item: return False return True w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue) filtered_dataset = w_dataset.select(filter_dataset) ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True) else: ds = TextImageDataset( args.image_text_folder, text_len=TEXT_SEQ_LEN, image_size=IMAGE_SIZE, resize_ratio=args.resize_ratio, truncate_captions=args.truncate_captions, tokenizer=tokenizer, shuffle=is_shuffle, ) assert len(ds) > 0, 'dataset is empty' if distr_backend.is_root_worker():
def dataio_prep_shards(hparams):
    """Build WebDataset pipelines for language-ID training and validation.

    Args:
        hparams: Mapping with at least ``train_meta``, ``val_meta``,
            ``sample_rate``, ``sentence_len``, ``save_folder``,
            ``train_shards``, ``val_shards`` and ``shard_cache_dir``.

    Returns:
        ``(train_data, valid_data, n_train, n_valid)``: the two pipelines and
        the ``num_data_samples`` values from the meta JSON files. The
        training pipeline repeats indefinitely and takes random crops; the
        validation pipeline yields whole clips.
    """

    def load_meta(path):
        # Meta files are JSON, opened through webdataset's gopen.
        with wds.gopen.gopen(path, "rb") as handle:
            return json.load(handle)

    train_meta = load_meta(hparams["train_meta"])
    val_meta = load_meta(hparams["val_meta"])

    # Length, in samples, of the crop taken from each training clip.
    snt_len_sample = int(hparams["sample_rate"] * hparams["sentence_len"])

    label_encoder = sb.dataio.encoder.CategoricalEncoder()
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_iterables=[train_meta["language_ids"]],
        output_key="lang_id",
    )

    def audio_pipeline(sample_dict: Dict, random_chunk=True):
        key = sample_dict["__key__"]
        language_id = sample_dict["language_id"].decode("ascii")
        waveform = sample_dict["audio.pth"].squeeze()
        n_samples = len(waveform)

        if random_chunk:
            last_start = n_samples - snt_len_sample - 1
            start = random.randint(0, last_start) if last_start > 0 else 0
            stop = start + snt_len_sample
        else:
            start, stop = 0, n_samples

        return {
            "sig": waveform[start:stop],
            "lang_id_encoded": label_encoder.encode_label(language_id),
            "id": key,
        }

    train_data = (
        wds.WebDataset(hparams["train_shards"], cache_dir=hparams["shard_cache_dir"])
        .repeat()
        .shuffle(1000)
        .decode("pil")
        .map(partial(audio_pipeline, random_chunk=True))
    )
    logger.info(
        f"Training data consist of {train_meta['num_data_samples']} samples")

    valid_data = (
        wds.WebDataset(hparams["val_shards"], cache_dir=hparams["shard_cache_dir"])
        .decode("pil")
        .map(partial(audio_pipeline, random_chunk=False))
    )
    logger.info(
        f"Validation data consist of {val_meta['num_data_samples']} samples")

    return (
        train_data,
        valid_data,
        train_meta["num_data_samples"],
        val_meta["num_data_samples"],
    )
def __init__(
    self,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    parser: Callable,
    sample_rate: int,
    int_values: bool = False,
    augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
    shuffle_n: int = 0,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    max_utts: int = 0,
    trim: bool = False,
    bos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
    pad_id: int = 0,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Tarred audio/text ASR dataset streamed through WebDataset.

    Args:
        audio_tar_filepaths: Shard list or pattern string. In a string,
            ``(``, ``[``, ``<``, ``_OP_`` become ``{`` and ``)``, ``]``, ``>``,
            ``_CL_`` become ``}``.
        manifest_filepath: Comma-separated manifest paths.
        parser: Text parser for ``collections.ASRAudioText``.
        sample_rate / int_values / augmentor: Passed to ``WaveformFeaturizer``.
        shuffle_n: In-shard shuffle buffer; ``<= 0`` disables shuffling.
        min_duration / max_duration / max_utts: Manifest filters.
        trim, bos_id, eos_id, pad_id: Stored on the instance.
        shard_strategy: ``"scatter"`` or ``"replicate"``; only applied when
            ``world_size > 1``.
        global_rank / world_size: Distributed rank and worker count.

    Raises:
        ValueError: If ``shard_strategy`` is not supported.
    """
    self.collection = collections.ASRAudioText(
        manifests_files=manifest_filepath.split(','),
        parser=parser,
        min_duration=min_duration,
        max_duration=max_duration,
        max_number=max_utts,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
    self.trim = trim
    self.eos_id, self.bos_id, self.pad_id = eos_id, bos_id, pad_id

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(
            f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        replacements = (
            ('(', "{"), ('[', "{"), ('<', "{"), ('_OP_', "{"),
            (')', "}"), (']', "}"), ('>', "}"), ('_CL_', "}"),
        )
        for token, brace in replacements:
            audio_tar_filepaths = audio_tar_filepaths.replace(token, brace)

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'replicate':
            logging.info(
                "All tarred dataset shards will be replicated across all nodes."
            )
        elif shard_strategy == 'scatter':
            logging.info(
                "All tarred dataset shards will be scattered evenly across all nodes."
            )
            n_shards = len(audio_tar_filepaths)
            if n_shards % world_size:
                logging.warning(
                    f"Number of shards in tarred dataset ({n_shards}) is not divisible "
                    f"by number of distributed workers ({world_size}).")
            # Contiguous slice per rank; leftover shards are dropped.
            per_rank = n_shards // world_size
            begin_idx = per_rank * global_rank
            end_idx = begin_idx + per_rank
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        else:
            raise ValueError(
                f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
            )

    # Put together WebDataset
    dataset = wd.WebDataset(audio_tar_filepaths)
    if shuffle_n <= 0:
        logging.info(
            "WebDataset will not shuffle files within the tar files.")
    else:
        dataset = dataset.shuffle(shuffle_n)

    self._dataset = (
        dataset.rename(audio='wav', key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )
def decode_to_torch(sample):
    """Turn a decoded WebDataset sample into a torch_geometric ``Data`` object.

    Uses ``points.pth`` as node positions and ``mask.pth`` as targets.
    ``class.cls`` is read but not used in the result.
    """
    from torch_geometric.data import Data
    import torch

    points = sample["points.pth"]
    mask = sample["mask.pth"]
    is_gamma = sample["class.cls"]  # read but currently unused
    # Only positions are kept; other per-point features are ignored.
    return Data(pos=points, y=mask)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()

    # Each raw sample is mapped through decode_to_torch after decoding.
    dataset = wds.Processor(
        wds.WebDataset(
            "/run/media/jacob/data/FACT_Dataset/fact-train-10-{0000..0040}.tar"
        ).shuffle(2000).decode(),
        wds.map,
        decode_to_torch,
    )
    test_dataset = wds.Processor(
        wds.WebDataset(
            "/run/media/jacob/data/FACT_Dataset/fact-test-5-{0000..0017}.tar"
        ).decode(),
        wds.map,
        decode_to_torch,
    )

    train_loader = DataLoader(dataset, num_workers=12, batch_size=1, pin_memory=True)
    test_loader = DataLoader(test_dataset, num_workers=1, batch_size=1, pin_memory=True)
def __init__(
    self,
    text_tar_filepaths: str,
    metadata_path: str,
    tokenizer,
    max_seq_length: int = 512,
    batch_step: int = None,
    shuffle_n: int = 1,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Tarred left-to-right language-modeling dataset streamed through WebDataset.

    Args:
        text_tar_filepaths: Shard list or brace-pattern string. In a string,
            ``(``, ``[``, ``<``, ``_OP_`` become ``{`` and ``)``, ``]``, ``>``,
            ``_CL_`` become ``}`` before brace expansion.
        metadata_path: JSON metadata file, stored as ``self.metadata``.
        tokenizer: Stored on the instance.
        max_seq_length: Maximum sequence length.
        batch_step: Step size; defaults to ``max_seq_length`` when falsy.
        shuffle_n: In-shard shuffle buffer; ``<= 0`` disables shuffling.
        shard_strategy: ``"scatter"`` partitions shards across ranks (only when
            ``world_size > 1``); ``"replicate"`` gives every rank all shards.
        global_rank: Rank of this process.
        world_size: Number of distributed workers (0 or 1 means no split).

    Raises:
        ValueError: If ``shard_strategy`` is not supported.
    """
    super(TarredL2RLanguageModelingDataset, self).__init__()
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.batch_step = batch_step or self.max_seq_length

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    self.metadata = metadata

    if isinstance(text_tar_filepaths, str):
        # Replace '(', '[', '<' and '_OP_' with '{'
        for bkey in ['(', '[', '<', '_OP_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")
        # Replace ')', ']', '>' and '_CL_' with '}'
        for bkey in [')', ']', '>', '_CL_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")
        # Brace expand
        text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

    if shard_strategy == 'scatter':
        logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
        # world_size defaults to 0, which made the modulo/division below raise
        # ZeroDivisionError; with 0 or 1 workers there is nothing to split.
        if world_size > 1:
            if len(text_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )
            # Contiguous slice per rank; leftover shards are dropped.
            begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
            )
    elif shard_strategy == 'replicate':
        logging.info("All tarred dataset shards will be replicated across all nodes.")
    else:
        raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    self.tarpath = text_tar_filepaths

    # Put together WebDataset
    self._dataset = wd.WebDataset(text_tar_filepaths)
    if shuffle_n > 0:
        self._dataset = self._dataset.shuffle(shuffle_n)
    else:
        logging.info("WebDataset will not shuffle files within the tar files.")

    self._dataset = self._dataset.rename(npy='npy', key='__key__').to_tuple('npy', 'key').map(f=self._build_sample)
def _iter_valid_subtiles(dataloader, tile_size):
    """Yield every informative ``tile_size`` subtile of the tiles in ``dataloader``.

    Skips non-square (incomplete) tiles, tiles whose values are all 0 or 1,
    and subtiles whose first band is constant.
    """
    for data in tqdm(dataloader):
        # The WebDataset branch yields 1-tuples (to_tuple("image")), which the
        # DataLoader collates into a sequence; unwrap to the image tensor.
        if isinstance(data, (list, tuple)):
            data = data[0]
        data = data.squeeze(0).numpy()
        # ignore incomplete tiles for stats
        if data.shape[-2] != data.shape[-1]:
            continue
        # check for empty tile and skip (all values are either 0 or 1):
        if np.isin(data, [0, 1]).all():
            continue
        for subtile_rgbn in make_blocks_vectorized(data, tile_size):
            if subtile_rgbn[0].min() != subtile_rgbn[0].max():
                yield subtile_rgbn


def main():
    """Compute per-band (red, green, blue, nir) mean and std over image tiles.

    Reads ``*.tar`` (WebDataset) or ``*.tif`` files from the given
    directories and writes ``processed.images.stats.json`` next to the first
    directory. ``--frac`` subsamples tif files only.

    Raises:
        SystemExit: If no valid subtile is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("datapath", type=Path, nargs="+")
    parser.add_argument(
        "--frac",
        dest="frac",
        type=float,
        default=1.0,
        help="fraction of tiles to consider [range: 0-1, def: %(default)s]",
    )
    args = parser.parse_args()

    np.random.seed(42)
    print("Using fixed random seed!")

    # constants
    tile_size = 256
    size = tile_size**2

    # nargs="+" guarantees args.datapath is a list of directories.
    tar_files = sorted(itertools.chain.from_iterable(x.glob("*.tar") for x in args.datapath))
    tif_files = sorted(itertools.chain.from_iterable(x.glob("*.tif") for x in args.datapath))

    if len(tar_files) > len(tif_files):
        # webdataset: iterable-style, so it cannot be wrapped in Subset
        # (Subset indexes the dataset, which WebDataset does not support).
        dataset = (wds.WebDataset([str(x) for x in tar_files])
                   .map(sample_decoder)
                   .rename(image="rgbn.tif", mask="mask.tif", stats="txt")
                   .map_dict(image=transform)
                   .to_tuple("image"))
        if args.frac < 1.0:
            print("Warning: --frac is ignored for WebDataset (tar) input")
    else:
        # plain source tif dataset, randomly subsampled to --frac of the files
        n_files = len(tif_files)
        n_subset = int(round(args.frac * n_files, 0))
        selection = np.random.choice(range(n_files), size=n_subset, replace=False)
        dataset = Subset(TifDataset(tif_files, transform=transform), selection)

    dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)

    print("\nCalculating STATS")
    print("\nCalculating MEAN")
    mean = np.zeros(4)
    cnt = 0
    for subtile_rgbn in _iter_valid_subtiles(dataloader, tile_size):
        mean += subtile_rgbn.sum((1, 2)) / size
        cnt += 1
    if cnt == 0:
        raise SystemExit("No valid subtiles found; cannot compute statistics.")
    # Average over the number of subtiles (previously divided by cnt + 1).
    mean /= cnt
    mean_unsqueezed = mean[:, np.newaxis, np.newaxis]

    print("\nCalculating STD")
    std = np.zeros(4)
    cnt = 0
    for subtile_rgbn in _iter_valid_subtiles(dataloader, tile_size):
        std += ((subtile_rgbn - mean_unsqueezed)**2).sum((1, 2)) / size
        cnt += 1
    if cnt == 0:
        raise SystemExit("No valid subtiles found; cannot compute statistics.")
    std = np.sqrt(std / cnt)

    df = pd.DataFrame({
        "band": ["red", "green", "blue", "nir"],
        "mean": mean.tolist(),
        "std": std.tolist(),
    })
    df = df.set_index("band")

    # report
    info = {
        "sources": [str(x) for x in args.datapath],
        "date": str(datetime.datetime.now()),
        "frac": args.frac,
        "subtiles": cnt,
        "results": json.loads(df.to_json(orient="index")),
    }

    # Serializing json
    with open(args.datapath[0].parent / "processed.images.stats.json", "w") as fout:
        fout.write(json.dumps(info, indent=4))
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    load_audio: bool = True,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Tarred speech-label dataset streamed through WebDataset.

    Args:
        audio_tar_filepaths: Shard list or pattern string. In a string,
            ``(``, ``[``, ``<``, ``_OP_`` become ``{`` and ``)``, ``]``, ``>``,
            ``_CL_`` become ``}``.
        manifest_filepath: Comma-separated manifest paths for
            ``collections.ASRSpeechLabel``.
        labels: Label vocabulary; falls back to the manifest's unique labels
            when empty/None.
        featurizer, trim, load_audio: Stored on the instance.
        shuffle_n: In-shard shuffle buffer; ``<= 0`` disables shuffling.
        min_duration / max_duration: Manifest duration filters.
        shard_strategy: ``"scatter"`` or ``"replicate"``; only applied when
            ``world_size > 1``.
        global_rank / world_size: Distributed rank and worker count.

    Raises:
        ValueError: If ``shard_strategy`` is not supported.
    """
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath.split(','),
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.file_occurence = count_occurence(self.collection.mapping)

    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    # Bidirectional label <-> index maps (later duplicates win, as before).
    self.label2id = {label: label_id for label_id, label in enumerate(self.labels)}
    self.id2label = dict(enumerate(self.labels))
    for idx in range(min(5, len(self.labels))):
        logging.debug(" label id {} and its mapped label {}".format(
            idx, self.id2label[idx]))

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(
            f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        replacements = (
            ('(', "{"), ('[', "{"), ('<', "{"), ('_OP_', "{"),
            (')', "}"), (']', "}"), ('>', "}"), ('_CL_', "}"),
        )
        for token, brace in replacements:
            audio_tar_filepaths = audio_tar_filepaths.replace(token, brace)

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'replicate':
            logging.info(
                "All tarred dataset shards will be replicated across all nodes."
            )
        elif shard_strategy == 'scatter':
            logging.info(
                "All tarred dataset shards will be scattered evenly across all nodes."
            )
            n_shards = len(audio_tar_filepaths)
            if n_shards % world_size:
                logging.warning(
                    f"Number of shards in tarred dataset ({n_shards}) is not divisible "
                    f"by number of distributed workers ({world_size}).")
            # Contiguous slice per rank; leftover shards are dropped.
            per_rank = n_shards // world_size
            begin_idx = per_rank * global_rank
            end_idx = begin_idx + per_rank
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        else:
            raise ValueError(
                f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
            )

    # Put together WebDataset
    dataset = wd.WebDataset(audio_tar_filepaths)
    if shuffle_n <= 0:
        logging.info(
            "WebDataset will not shuffle files within the tar files.")
    else:
        dataset = dataset.shuffle(shuffle_n)

    self._dataset = (
        dataset.rename(audio='wav', key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )
def main():
    """Tile source ``*.tif`` files into WebDataset shards and build a combo training set.

    All outputs are written to ``<outdir>/<subdir>``:

    1. Image, mask and land-use (``lu``) files present in all three input
       directories (matched by file name) are split into subtiles by
       ``split_tiles`` (``train-%06d.tar``). Per-subtile ``(tile, frac, status)``
       rows go to ``<outdir>/<stats>``.
    2. Subtiles with ``status > 0`` are re-packed into ``train-balanced-*.tar``
       shards of exactly ``SHARDSIZE`` samples.
    3. ``n_valid * OVERSAMPLE_FACTOR`` random subtiles are drawn. They are not
       status-1 training subtiles and come from tiles that have an LU file.
       They are written to ``train-randomsamples-*.tar``, with stats in
       ``<stats stem>_rnd.csv``.
    4. Balanced and random samples are interleaved into ``train-combo-*.tar``.
       The intermediate shards are then deleted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("image_dir", type=Path)
    parser.add_argument("mask_dir", type=Path)
    parser.add_argument("lu_dir", type=Path)
    parser.add_argument("outdir", type=Path)

    num_cores = psutil.cpu_count(logical=False)

    parser.add_argument(
        "--workers",
        dest="workers",
        type=int,
        default=num_cores,
        help="number of workers for parallel execution [def: %(default)s]",
    )
    parser.add_argument(
        "--source_dim",
        dest="source_dim",
        type=int,
        default=2048,
        help="size of input tiles [def: %(default)s]",
    )
    parser.add_argument(
        "--tile_size",
        dest="tile_size",
        type=int,
        default=256,
        help="size of final tiles that are then passed to the model [def: %(default)s]",
    )
    parser.add_argument(
        "--format",
        dest="format",
        type=str,
        default="TIFF",
        choices=["PNG", "TIFF"],
        help="target file format (PNG, TIFF) [def: %(default)s]",
    )
    parser.add_argument(
        "--tmp-dir",
        dest="tmp_dir",
        type=Path,
        default=None,
        help="use this location as tmp dir",
    )
    parser.add_argument(
        "--subdir",
        dest="sub_dir",
        default="train",
        help="use this location as sub_dir",
    )
    parser.add_argument(
        "--stats",
        dest="stats_file",
        type=Path,
        default=Path("stats.csv"),
        help="use this file to record stats",
    )

    args = parser.parse_args()

    shard_dir = args.outdir / args.sub_dir
    args.outdir.mkdir(parents=True, exist_ok=True)
    shard_dir.mkdir(parents=True, exist_ok=True)

    if args.tmp_dir:
        print(f"Using custom tmp dir: {args.tmp_dir}")
        Path(args.tmp_dir).mkdir(parents=True, exist_ok=True)

    if args.format == "TIFF":
        suffix = "tif"
    elif args.format == "PNG":
        suffix = "png"
    else:
        raise NotImplementedError

    # Shuffle subtile order within shards. This uses the global `random` state;
    # seed it (outside this function) if the order must be reproducible.
    SHUFFLE = True

    # Matches the per-tile shards written by split_tiles ("train-%06d.tar") at any
    # shard count, without matching train-balanced/-randomsamples/-combo shards.
    source_shard_pattern = "train-[0-9]*.tar"

    images = sorted(args.image_dir.glob("*.tif"))
    masks = sorted(args.mask_dir.glob("*.tif"))
    lus = sorted(args.lu_dir.glob("*.tif"))

    image_names = {i.name for i in images}
    mask_names = {i.name for i in masks}
    lu_names = {i.name for i in lus}

    # Keep only tiles present in all three directories. The intersection is
    # computed once; the filtered lists stay sorted and therefore aligned by name.
    common_names = image_names & mask_names & lu_names
    train_images = [i for i in images if i.name in common_names]
    train_masks = [i for i in masks if i.name in common_names]
    train_lus = [i for i in lus if i.name in common_names]

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
    )

    subtile_stats = split_tiles(
        train_images,
        train_masks,
        train_lus,
        args.workers,
        str(shard_dir / "train-%06d.tar"),
        **cfg,
    )

    with open(args.outdir / args.stats_file, "w") as fout:
        fout.write("tile,frac,status\n")
        for fname, frac, status in subtile_stats:
            fout.write(f"{fname},{frac},{status}\n")

    # Rebalance shards so that all shards get similar distributions.
    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmpdir:
        print(f"Created a temporary directory: {tmpdir}")

        print("Extract source tars")
        for tf_name in sorted(shard_dir.glob(source_shard_pattern)):
            with tarfile.open(tf_name) as tf:
                tf.extractall(tmpdir)

        print("Write balanced shards from deadtree samples")

        df = pd.read_csv(args.outdir / args.stats_file)
        df = df[df.status > 0]
        n_valid = len(df)

        splits = split_df(df, SHARDSIZE)

        # Keep the last shard if it is more than half full, padding it with
        # duplicates (not ideal...). Guard against split_df returning nothing.
        if len(splits) > 0 and SHARDSIZE // 2 < len(splits[-1]) < SHARDSIZE:
            n_missing = SHARDSIZE - len(splits[-1])
            splits[-1].extend(np.random.choice(splits[-1], size=n_missing).tolist())

        # Drop incomplete shards.
        splits = [x for x in splits if len(x) == SHARDSIZE]
        if not splits:
            raise RuntimeError(
                f"No complete shard of {SHARDSIZE} samples could be formed "
                f"from {n_valid} valid subtiles"
            )

        for s_cnt, s in enumerate(splits):
            with tarfile.open(shard_dir / f"train-balanced-{s_cnt:06}.tar", "w") as dst:
                if SHUFFLE:
                    random.shuffle(s)
                for i in s:
                    for part in (f"mask.{suffix}", f"lu.{suffix}", f"rgbn.{suffix}", "txt"):
                        dst.add(f"{tmpdir}/{i}.{part}", f"{i}.{part}")

    # Random-tile dataset: draw from all subtiles that are not status-1 training
    # subtiles. split_tiles needs an LU file per image, so only tiles that have
    # one qualify.
    lu_by_name = {p.name: p for p in lus}
    n_subtiles = (args.source_dim // args.tile_size) ** 2
    all_subtiles = {
        f"{Path(image_name).stem}_{c:03}"
        for image_name in image_names
        if image_name in lu_by_name
        for c in range(n_subtiles)
    }

    n_samples = n_valid * OVERSAMPLE_FACTOR
    train_subtiles = {x[0] for x in subtile_stats if int(x[2]) == 1}
    # Sort the pool. Set iteration order changes between interpreter runs
    # (string-hash randomization), so sampling straight from a set-derived
    # tuple is not reproducible even under a fixed seed.
    candidates = sorted(all_subtiles - train_subtiles)
    if n_samples > len(candidates):
        print(
            f"WARNING: requested {n_samples} random subtiles but only "
            f"{len(candidates)} are available; using all of them"
        )
        n_samples = len(candidates)
    random_subtiles = random.sample(candidates, n_samples)

    # Tiles that must be processed. Strip the trailing "_<index>" (any index width).
    random_tiles = {x.rsplit("_", 1)[0] for x in random_subtiles}
    random_images = [x for x in images if x.stem in random_tiles]
    # Pair LU files by name so they stay positionally aligned with random_images.
    random_lus = [lu_by_name[x.name] for x in random_images]

    print("STATS")
    print(len(all_subtiles))
    print(len(subtile_stats))
    print(len(random_subtiles))
    print(len(random_images))

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
        valid_subtiles=random_subtiles,  # subset data with random selection of subtiles
    )

    subtile_stats_rnd = split_tiles(
        random_images,
        [None] * len(random_images),
        random_lus,
        args.workers,
        str(shard_dir / "train-randomsamples-%06d.tar"),
        **cfg,
    )

    stats_file_rnd = Path(args.stats_file.stem + "_rnd.csv")

    with open(args.outdir / stats_file_rnd, "w") as fout:
        fout.write("tile,frac,status\n")
        for fname, frac, status in subtile_stats_rnd:
            fout.write(f"{fname},{frac},{status}\n")

    # Combo dataset. Source A is train-balanced, source B is randomsamples.
    # NOTE: combo shards hold 2 * SHARDSIZE samples, alternating between a
    # regular sample and a random sample.
    train_balanced_shards = [str(x) for x in sorted(shard_dir.glob("train-balanced*"))]
    train_balanced_shards_rnd = [str(x) for x in sorted(shard_dir.glob("train-random*"))]
    train_balanced_shards_rnd = train_balanced_shards_rnd[: len(train_balanced_shards)]

    shardpattern = str(shard_dir / "train-combo-%06d.tar")

    with wds.ShardWriter(shardpattern, maxcount=SHARDSIZE * 2) as sink:
        for shardA, shardB in zip(train_balanced_shards, train_balanced_shards_rnd):
            for sA, sB in zip(wds.WebDataset(shardA), wds.WebDataset(shardB)):
                sink.write(sA)
                sink.write(sB)

    # Remove everything except the combo shards.
    for pattern in ("train-random*", "train-balanced*", source_shard_pattern):
        for filename in sorted(shard_dir.glob(pattern)):
            filename.unlink()
def setup(
    self,
    split_fractions: List[float] = DeadtreeDatasetConfig.fractions,
    in_channels: Optional[int] = 4,  # change to 3 for rgb training instead of rgbn
    classes: Optional[int] = 3,  # change to 2 for single class (+bg) setup
) -> None:
    """Create the train/validation(/test) WebDataset pipelines plus any extra datasets.

    With ``self.layout == "single_directory"``, ``self.data_shards`` is split by
    ``split_fractions``. Otherwise it must already be a ``(train, valid, test)``
    triple. The number of samples per shard is estimated by iterating the first
    training shard. That estimate is the training shuffle buffer and sets the
    nominal pipeline length. Validation and test pipelines are not shuffled.
    ``self.test_data`` is only assigned when test shards exist.

    For each ``(batch_size, shards)`` pair in ``self.batch_size_extra`` /
    ``self.data_shards_extra``, the shards are split with the main dataset's
    train/valid ratio. The resulting pipelines are appended to
    ``self.extra_train_data`` / ``self.extra_valid_data``.
    """
    if self.layout == "single_directory":
        shard_splits = split_shards(self.data_shards, split_fractions)
    else:
        shard_splits = self.data_shards
    train_shards, valid_shards, test_shards = shard_splits

    train_shards = list(map(str, train_shards))
    valid_shards = list(map(str, valid_shards))
    test_shards = list(map(str, test_shards))

    # Estimate the dataset length by counting the samples in one shard.
    probe_shard = train_shards[0]
    shard_size = sum(1 for _ in DataLoader(wds.WebDataset(probe_shard)))
    logger.info(f"Shard size: {shard_size} (estimate base on file: {probe_shard})")

    def make_pipeline(
        shards: List[str],
        bs: int,
        transform_func: Callable,
        shuffle: Optional[int] = 128,
        shard_size: Optional[int] = 128,
    ) -> wds.WebDataset:
        """Decode, rename and transform samples into (image, mask, distmap, lu, stats) tuples."""
        nominal_length = len(shards) * shard_size // bs
        pipeline = wds.WebDataset(shards, length=nominal_length)
        pipeline = pipeline.shuffle(shuffle).map(sample_decoder)
        pipeline = pipeline.rename(image="rgbn.tif", mask="mask.tif", lu="lu.tif", stats="txt")
        per_sample = partial(
            transform,
            transform_func=transform_func,
            in_channels=in_channels,
            classes=classes,
            distmap=True,
        )
        return pipeline.map(per_sample).to_tuple("image", "mask", "distmap", "lu", "stats")

    train_bs = self.train_dataloader_conf["batch_size"]
    self.train_data = make_pipeline(
        train_shards,
        train_bs,
        transform_func=train_transform,
        shuffle=shard_size,
        shard_size=shard_size,
    )

    val_bs = self.val_dataloader_conf["batch_size"]
    self.val_data = make_pipeline(
        valid_shards,
        val_bs,
        transform_func=val_transform,
        shuffle=0,
        shard_size=shard_size,
    )

    if test_shards:
        test_bs = self.test_dataloader_conf["batch_size"]
        self.test_data = make_pipeline(
            test_shards,
            test_bs,
            transform_func=val_transform,
            shuffle=0,
            shard_size=shard_size,
        )

    self.extra_train_data = []
    self.extra_valid_data = []

    if len(self.data_shards_extra) > 0:
        for bs, shards in zip(self.batch_size_extra, self.data_shards_extra):
            # Use the same train/valid proportion as the main dataset.
            n_train, n_valid = len(train_shards), len(valid_shards)
            train_frac = n_train / (n_train + n_valid)
            extra_train_shards, extra_valid_shards, _ = split_shards(
                shards, [train_frac, 1 - train_frac]
            )

            train_pipeline = make_pipeline(
                extra_train_shards,
                bs,
                transform_func=train_transform,
                shuffle=shard_size,
                shard_size=shard_size,
            )
            self.extra_train_data.append(train_pipeline)

            valid_pipeline = make_pipeline(
                extra_valid_shards,
                bs,
                transform_func=val_transform,
                shuffle=0,
                shard_size=shard_size,
            )
            self.extra_valid_data.append(valid_pipeline)
import io
import os
import pickle
import shutil
import torch
import webdataset as wds
from torch.utils.data import TensorDataset
import tarfile


def _archive_group(group):
    """Pack ``preprocessed_data/<group>/`` into ``preprocessed_data_tars/<group>.tar.bz2``, then delete the directory."""
    with tarfile.open(f"preprocessed_data_tars/{group}.tar.bz2", mode="w:bz2") as tar:
        tar.add("preprocessed_data/" + group + "/")
    shutil.rmtree("preprocessed_data/" + group, ignore_errors=True)


if __name__ == '__main__':
    # Decode GT and BW images. Each batch holds one sample, so `name` is a
    # one-element list of keys.
    dataset = (
        wds.WebDataset("train_0000000.tar", length=float("inf"))
        .decode(my_decoder_GT)
        .decode(my_decoder_BW)
        .to_tuple("gt.jpg", "train.jpg", "__key__")
        .batched(1)
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
    os.makedirs("preprocessed_data_tars", exist_ok=True)

    # Samples are grouped by the key prefix before the first "/". A group is
    # archived once its samples are done. Re-archiving after every sample (in
    # "w" mode) would overwrite the archive and keep only the last sample.
    current_group = None
    archived_groups = set()
    for i, (gt, bw, name) in enumerate(dataloader):
        print(f"{i + 1} / 1024")
        key = name[0]
        group, sep, _ = key.partition("/")
        if not sep:
            raise ValueError(f"Expected a sample key of the form '<group>/<sample>', got {key!r}")

        if group != current_group:
            if current_group is not None:
                _archive_group(current_group)
                archived_groups.add(current_group)
            if group in archived_groups:
                print(
                    f"WARNING: group {group!r} reappeared after being archived; "
                    "its archive will be overwritten"
                )
            current_group = group

        os.makedirs("preprocessed_data/" + group, exist_ok=True)
        torch.save(gt, "preprocessed_data/" + key + ".gt.pt")
        torch.save(bw, "preprocessed_data/" + key + ".train.pt")

    if current_group is not None:
        _archive_group(current_group)
def __init__(
    self,
    text_tar_filepaths: str,
    metadata_path: str,
    encoder_tokenizer: str,
    decoder_tokenizer: str,
    shuffle_n: int = 1,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
    reverse_lang_direction: bool = False,
):
    """Build a WebDataset pipeline over tarred translation shards.

    Args:
        text_tar_filepaths: Shard path pattern or list of paths. In a string,
            ``(``, ``[``, ``<`` and ``_OP_`` become ``{`` and ``)``, ``]``, ``>``
            and ``_CL_`` become ``}``. The string is then brace-expanded into a list.
        metadata_path: JSON file loaded into ``self.metadata``.
        encoder_tokenizer: Despite the ``str`` annotation, must expose ``pad_id``
            (stored as ``self.src_pad_id``).
        decoder_tokenizer: Despite the ``str`` annotation, must expose ``pad_id``
            (stored as ``self.tgt_pad_id``).
        shuffle_n: Shuffle buffer size; ``<= 0`` disables shuffling.
        shard_strategy: ``'scatter'`` gives each rank a contiguous, equally sized
            slice of the shard list; ``'replicate'`` lets every rank read all shards.
        global_rank: This process's rank; selects the scatter slice.
        world_size: Number of ranks. Shards are only partitioned when it is > 1;
            with 0 (the default) or 1 the full shard list is used.
        reverse_lang_direction: Stored as an attribute.

    Raises:
        ValueError: If ``shard_strategy`` is not ``'scatter'`` or ``'replicate'``.
    """
    super(TarredTranslationDataset, self).__init__()

    self.encoder_tokenizer = encoder_tokenizer
    self.decoder_tokenizer = decoder_tokenizer
    self.reverse_lang_direction = reverse_lang_direction
    self.src_pad_id = encoder_tokenizer.pad_id
    self.tgt_pad_id = decoder_tokenizer.pad_id

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    self.metadata = metadata

    if isinstance(text_tar_filepaths, str):
        # Replace '(', '[', '<' and '_OP_' with '{'
        for bkey in ['(', '[', '<', '_OP_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")
        # Replace ')', ']', '>' and '_CL_' with '}'
        for bkey in [')', ']', '>', '_CL_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")
        # Brace expand
        text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

    # Only partition across ranks when there is more than one. With the default
    # world_size=0, the modulo and floor division below raise ZeroDivisionError.
    if world_size > 1:
        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            if len(text_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )
            begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
            logging.info('Begin Index : %d' % (begin_idx))
            logging.info('End Index : %d' % (end_idx))
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
            )
        else:  # 'replicate' (validated above)
            logging.info("All tarred dataset shards will be replicated across all nodes.")

    self.tarpath = text_tar_filepaths

    # Put together WebDataset
    self._dataset = wd.WebDataset(text_tar_filepaths)

    if shuffle_n > 0:
        self._dataset = self._dataset.shuffle(shuffle_n)
    else:
        logging.info("WebDataset will not shuffle files within the tar files.")

    self._dataset = self._dataset.rename(pkl='pkl', key='__key__').to_tuple('pkl', 'key').map(f=self._build_sample)
def __init__(
    self,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    parser: Callable,
    sample_rate: int,
    int_values: bool = False,
    augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
    shuffle_n: int = 0,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    max_utts: int = 0,
    trim: bool = False,
    bos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
    pad_id: int = 0,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
    return_sample_id: bool = False,
):
    """Index an ASR manifest and build a WebDataset over tarred audio shards.

    The manifest is processed by ``ASRManifestProcessor`` with
    ``index_by_file_id=True``. Audio is featurized by a ``WaveformFeaturizer``
    configured with ``sample_rate``, ``int_values`` and ``augmentor``. Shard
    paths are resolved (and, depending on ``shard_strategy``, ``world_size`` and
    ``global_rank``, partitioned) by ``expand_audio_filepaths``. The resulting
    pipeline reads ``wav``/``ogg``/``flac`` members as ``audio`` with their
    ``__key__``, then applies ``self._filter``, ``self._loop_offsets`` and
    ``self._build_sample``.

    Args:
        shuffle_n: Shuffle buffer size; ``<= 0`` disables shuffling.
    """
    self.manifest_processor = ASRManifestProcessor(
        manifest_filepath=manifest_filepath,
        parser=parser,
        max_duration=max_duration,
        min_duration=min_duration,
        max_utts=max_utts,
        bos_id=bos_id,
        eos_id=eos_id,
        pad_id=pad_id,
        # Required so that manifest lines can be looked up by file ID.
        index_by_file_id=True,
    )

    self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
    self.trim = trim
    self.eos_id, self.bos_id, self.pad_id = eos_id, bos_id, pad_id
    self.return_sample_id = return_sample_id

    shard_urls = expand_audio_filepaths(
        audio_tar_filepaths=audio_tar_filepaths,
        shard_strategy=shard_strategy,
        world_size=world_size,
        global_rank=global_rank,
    )

    # nodesplitter=None: node-level splitting is not delegated to WebDataset here.
    pipeline = wd.WebDataset(urls=shard_urls, nodesplitter=None)
    if shuffle_n <= 0:
        logging.info("WebDataset will not shuffle files within the tar files.")
    else:
        pipeline = pipeline.shuffle(shuffle_n)

    pipeline = pipeline.rename(audio='wav;ogg;flac', key='__key__')
    pipeline = pipeline.to_tuple('audio', 'key')
    pipeline = pipeline.pipe(self._filter)
    pipeline = pipeline.pipe(self._loop_offsets)
    self._dataset = pipeline.map(f=self._build_sample)