Beispiel #1
0
def from_processed(url: str, train=False):
    """Load ``Example`` records from the WebDataset shards matching ``url``.

    Args:
        url: Glob pattern; the matching paths are sorted before use.
        train: When true, return a lazily evaluated, shuffled pipeline.
            Otherwise every sample is read eagerly into a list.

    Returns:
        A WebDataset pipeline (``train=True``) or a list of ``Example``
        objects (``train=False``).
    """

    def to_example(sample):
        # The "json" entry of each decoded sample holds the Example kwargs.
        return Example(**sample["json"])

    shard_paths = sorted(glob.glob(url))
    source = wds.WebDataset(shard_paths)
    if not train:
        return list(source.decode().map(to_example))
    return source.shuffle(size=10000000, initial=100000).decode().map(to_example)
Beispiel #2
0
 def build_dataset(
     shards: List[str],
     bs: int,
     transform_func: Callable,
     shuffle: Optional[int] = 128,
     shard_size: Optional[int] = 128,
 ) -> wds.WebDataset:
     """Assemble the WebDataset pipeline for the given shards.

     Args:
         shards: Shard URLs/paths handed to ``wds.WebDataset``.
         bs: Batch size; only used to compute the nominal dataset length.
         transform_func: Forwarded to ``transform`` as ``transform_func``.
         shuffle: Argument passed to ``.shuffle()``.
         shard_size: Multiplied by ``len(shards)`` to estimate the sample
             count (presumably samples per shard -- confirm).

     Returns:
         A pipeline yielding ``(image, mask, distmap, lu, stats)`` tuples.
     """
     # Nominal length in batches, derived from the estimated sample count.
     nominal_length = len(shards) * shard_size // bs

     pipeline = wds.WebDataset(shards, length=nominal_length)
     pipeline = pipeline.shuffle(shuffle).map(sample_decoder)
     pipeline = pipeline.rename(
         image="rgbn.tif", mask="mask.tif", lu="lu.tif", stats="txt"
     )

     # `in_channels` and `classes` are taken from the enclosing scope.
     per_sample_transform = partial(
         transform,
         transform_func=transform_func,
         in_channels=in_channels,
         classes=classes,
         distmap=True,
     )
     pipeline = pipeline.map(per_sample_transform)
     return pipeline.to_tuple("image", "mask", "distmap", "lu", "stats")
Beispiel #3
0
def display_shard_images(client, bucket, tar_name, objects=2, etl_id=""):
    """Show the first ``objects`` images of a tar shard in a two-column grid.

    The shard is streamed through WebDataset from the URL returned by
    ``client.object_url(bucket, tar_name, transform_id=etl_id)``.

    Args:
        client: Object exposing ``object_url(bucket, name, transform_id=...)``.
        bucket: Passed to ``client.object_url``.
        tar_name: Passed to ``client.object_url`` as the object name.
        objects: Number of images to draw.
        etl_id: Passed to ``client.object_url`` as ``transform_id``.

    Raises:
        StopIteration: If the shard yields fewer than ``objects`` samples.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    test_object = (
        wds.WebDataset(
            client.object_url(bucket, tar_name, transform_id=etl_id),
            handler=wds.handlers.warn_and_continue,
        )
        .decode("rgb")
        .to_tuple("jpg;png;jpeg;npy cls", handler=wds.handlers.warn_and_continue)
        .map_tuple(to_tensor, lambda x: x)
    )

    test_loader = wds.WebLoader(
        test_object,
        batch_size=None,
        shuffle=False,
        num_workers=1,
    )
    test_iter = iter(test_loader)

    columns = 2
    # Round up so an odd number of images still gets a row for the last one,
    # and keep at least one row because matplotlib rejects an empty grid.
    rows = max(1, (objects + columns - 1) // columns)
    # squeeze=False keeps `axarr` two-dimensional even for a single row;
    # otherwise axarr[row, column] fails for the default objects=2.
    _, axarr = plt.subplots(rows, columns, figsize=(12, 12), squeeze=False)

    for i in range(objects):
        row, column = divmod(i, columns)
        img_tensor, _ = next(test_iter)
        # Move the leading axis last for imshow (assumes a channel-first
        # image tensor -- TODO confirm), then clamp to the [0, 1] range.
        img = np.transpose(np.asarray(img_tensor.squeeze()), (1, 2, 0))
        img = np.clip(img, 0, 1)
        ax = axarr[row, column]
        ax.set_yticks([])
        ax.set_xticks([])
        ax.imshow(img, interpolation="nearest")

    # Hide grid cells that received no image.
    for unused_ax in axarr.flat[objects:]:
        unused_ax.axis("off")

    plt.show()
def make_train_loader_wds(args):
    """Create the training ``wds.WebLoader`` configured by ``args``.

    The attributes read from ``args`` are ``trainshards``, ``trainsize``,
    ``batch_size``, ``shuffle``, ``workers``, ``distributed`` and, in the
    distributed case, ``world_size``.

    Returns:
        A loader yielding batches that were already formed inside the
        dataset pipeline (the loader itself uses ``batch_size=None``).
    """
    print("=> using WebDataset loader")
    train_transform = make_train_transform(args)
    batches_per_epoch = args.trainsize // args.batch_size

    samples = wds.WebDataset(args.trainshards, length=batches_per_epoch)
    samples = samples.shuffle(args.shuffle).decode("pil")
    samples = samples.to_tuple("jpg;png;jpeg cls")
    samples = samples.map_tuple(train_transform, identity)

    # Partial batches are dropped only for DistributedDataParallel runs;
    # otherwise `batched` is called with its default setting.
    batching_options = {"partial": False} if args.distributed else {}
    batched_samples = samples.batched(args.batch_size, **batching_options)

    # WebLoader is the regular DataLoader plus WebDataset's convenience methods.
    train_loader = wds.WebLoader(
        batched_samples,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
    )
    if not args.distributed:
        return train_loader

    # Under DDP every node must see the same number of batches, so the data
    # is repeated and then cut to a fixed count. This is only needed for code
    # that depends on an epoch size; iterating the whole dataset on all nodes
    # is the better approach.
    # NOTE(review): hard-coded sample count, independent of args.trainsize;
    # looks like the ImageNet-1k training-set size -- confirm.
    dataset_size = 1281167
    number_of_batches = dataset_size // (args.batch_size * args.world_size)
    print("# batches per node = ", number_of_batches)
    train_loader = train_loader.repeat(2).slice(number_of_batches)
    # Only affects what len() reports; some frameworks query it.
    train_loader.length = number_of_batches
    return train_loader
Beispiel #5
0
    def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128):
        """Load the manifest and open an iterator over the tarred audio.

        Args:
            manifest_path: Manifest handed to ``collections.ASRAudioText``.
            tar_filepaths: Shard path(s). In a string, ``(``, ``[``, ``<`` and
                ``_OP_`` are rewritten to ``{`` and ``)``, ``]``, ``>`` and
                ``_CL_`` to ``}``.
            shuffle_n: Argument for ``.shuffle()``; values <= 0 disable
                shuffling.

        Raises:
            LightningNotInstalledException: If ``HAVE_OMEGACONG_WEBDATASET``
                is falsy.
        """
        self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)

        if isinstance(tar_filepaths, str):
            # Normalise the alternative brace spellings: all opening tokens
            # first, then all closing tokens.
            for opener in ('(', '[', '<', '_OP_'):
                tar_filepaths = tar_filepaths.replace(opener, "{")
            for closer in (')', ']', '>', '_CL_'):
                tar_filepaths = tar_filepaths.replace(closer, "}")

        if not HAVE_OMEGACONG_WEBDATASET:
            raise LightningNotInstalledException(self)

        self.audio_dataset = wd.WebDataset(urls=tar_filepaths, nodesplitter=None)
        if shuffle_n <= 0:
            logging.info("WebDataset will not shuffle files within the tar files.")
        else:
            self.audio_dataset = self.audio_dataset.shuffle(shuffle_n)

        # Expose each sample as an (audio, key) pair.
        renamed = self.audio_dataset.rename(audio='wav', key='__key__')
        self.audio_dataset = renamed.to_tuple('audio', 'key')
        self.audio_iter = iter(self.audio_dataset)
Beispiel #6
0
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: Union[str, List[str]],
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        is_regression_task: bool = False,
    ):
        """Set up label bookkeeping and the tarred-audio WebDataset pipeline.

        Args:
            audio_tar_filepaths: Shard path(s), resolved through
                ``expand_audio_filepaths``.
            manifest_filepath: Manifest(s) for ``collections.ASRSpeechLabel``.
            labels: Label list; when empty or ``None`` the unique labels of
                the manifest collection are used instead.
            featurizer: Stored as ``self.featurizer``.
            shuffle_n: Argument for ``.shuffle()``; values <= 0 disable
                shuffling.
            min_duration: Forwarded to the manifest collection.
            max_duration: Forwarded to the manifest collection.
            trim: Stored as ``self.trim``.
            shard_strategy: Forwarded to ``expand_audio_filepaths``.
            global_rank: Forwarded to ``expand_audio_filepaths``.
            world_size: Forwarded to ``expand_audio_filepaths``.
            is_regression_task: Accepted but not used in this constructor.
        """
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath,
            min_duration=min_duration,
            max_duration=max_duration,
            # Required so manifest lines can be looked up by file ID.
            index_by_file_id=True,
        )

        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim

        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        # Two-way mapping between labels and their positional ids.
        self.label2id = {label: label_id for label_id, label in enumerate(self.labels)}
        self.id2label = dict(enumerate(self.labels))

        # Log at most the first five mappings.
        for idx in range(min(5, len(self.labels))):
            logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

        audio_tar_filepaths = expand_audio_filepaths(
            audio_tar_filepaths=audio_tar_filepaths,
            shard_strategy=shard_strategy,
            world_size=world_size,
            global_rank=global_rank,
        )

        # Put together WebDataset
        self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)
        if shuffle_n <= 0:
            logging.info("WebDataset will not shuffle files within the tar files.")
        else:
            self._dataset = self._dataset.shuffle(shuffle_n)

        audio_and_key = self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__').to_tuple('audio', 'key')
        self._dataset = audio_and_key.pipe(self._filter).map(f=self._build_sample)
    def setup(self, stage=None):
        """Create the train, validation and test WebDataset pipelines.

        Shard URLs are built below the directory given by the ``train_glob``
        entry of ``self.args`` (default ``/pvc/output/processing``), using
        the counts returned by ``self.get_num_files`` as the upper bound of
        the brace range.

        Args:
            stage: Stage - training or testing (not used by this method).
        """
        data_path = self.args.get("train_glob", "/pvc/output/processing")

        train_base_url = data_path + "/train"
        val_base_url = data_path + "/val"
        test_base_url = data_path + "/test"

        # Counts are queried in train, val, test order before any URL is built.
        train_count = self.get_num_files(train_base_url)
        val_count = self.get_num_files(val_base_url)
        test_count = self.get_num_files(test_base_url)

        def shard_url(base_url, split, count):
            # e.g. "<base>/train-{0..<count>}.tar"
            return f"{base_url}/{split}-{{0..{count!s}}}.tar"

        def make_pipeline(url, image_extensions, image_transform, batch_size):
            return (
                wds.WebDataset(url, handler=wds.warn_and_continue)
                .shuffle(100)
                .decode("pil")
                .rename(image=image_extensions, info="cls")
                .map_dict(image=image_transform)
                .to_tuple("image", "info")
                .batched(batch_size)
            )

        train_url = shard_url(train_base_url, "train", train_count)
        valid_url = shard_url(val_base_url, "val", val_count)
        test_url = shard_url(test_base_url, "test", test_count)

        self.train_dataset = make_pipeline(
            train_url, "ppm;jpg;jpeg;png", self.train_transform, 40)
        # Validation and test share `valid_transform` and the batch size.
        self.valid_dataset = make_pipeline(
            valid_url, "ppm", self.valid_transform, 20)
        self.test_dataset = make_pipeline(
            test_url, "ppm", self.valid_transform, 20)
Beispiel #8
0
def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train ``LitPointNet2`` and report its ``val/loss``.

    Args:
        trial: Trial used to sample the model configuration and to drive the
            pruning callback.

    Returns:
        The ``val/loss`` entry of the trainer's callback metrics after fitting.
    """
    data_root = "/run/media/jacob/data/FACT_Dataset"

    def open_shards(pattern):
        return wds.WebDataset(f"{data_root}/{pattern}")

    # Only the training streams are shuffled.
    gamma_train = open_shards("fact-gamma-10-{0000..0062}.tar").shuffle(20000).decode()
    proton_train = open_shards("fact-proton-10-{0000..0010}.tar").shuffle(20000).decode()
    gamma_test = open_shards("fact-gamma-10-{0063..0072}.tar").decode()
    proton_test = open_shards("fact-proton-10-{0011..0013}.tar").decode()

    train_set = SampleEqually([gamma_train, proton_train])
    test_set = SampleEqually([gamma_test, proton_test])

    train_loader = DataLoader(train_set, num_workers=16, batch_size=4, pin_memory=True)
    test_loader = DataLoader(test_set, num_workers=4, batch_size=1, pin_memory=True)

    # Hyperparameters sampled for this trial, in a fixed order.
    config = {
        "sample_ratio_one": trial.suggest_uniform("sample_ratio_one", 0.1, 0.9),
        "sample_radius_one": trial.suggest_uniform("sample_radius_one", 0.1, 0.9),
        "sample_max_neighbor": trial.suggest_int("sample_max_neighbor", 8, 72),
        "sample_ratio_two": trial.suggest_uniform("sample_ratio_two", 0.1, 0.9),
        "sample_radius_two": trial.suggest_uniform("sample_radius_two", 0.1, 0.9),
        "fc_1": trial.suggest_int("fc_1", 128, 256),
        "fc_1_out": trial.suggest_int("fc_1_out", 32, 128),
        "fc_2_out": trial.suggest_int("fc_2_out", 16, 96),
        "dropout": trial.suggest_uniform("dropout", 0.1, 0.9),
    }

    num_classes = 2
    import pytorch_lightning as pl
    model = LitPointNet2(num_classes, lr=0.0001, config=config)

    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val/loss")
    trainer = pl.Trainer(
        logger=True,
        limit_val_batches=10000,
        limit_train_batches=10000,
        checkpoint_callback=False,
        auto_lr_find=True,
        max_epochs=20,
        gpus=1,
        callbacks=[pruning_callback],
    )
    trainer.logger.log_hyperparams(config)

    # The same loaders are used for tuning and for the actual fit.
    loaders = {"train_dataloader": train_loader, "val_dataloaders": test_loader}
    trainer.tune(model=model, **loaders)
    trainer.fit(model=model, **loaders)

    return trainer.callback_metrics["val/loss"].item()
Beispiel #9
0
def loader(urls, batch_size, workers):
    """Build a ``wds.WebLoader`` over ``urls`` yielding ``(npy, cls)`` pairs.

    Args:
        urls: Shard URLs; ``len(urls)`` feeds the nominal length.
        batch_size: Batch size used by the loader.
        workers: Number of loader worker processes.

    Returns:
        The loader, with its length set to the nominal batch count.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    on_error = wds.handlers.warn_and_continue

    samples = wds.WebDataset(urls, handler=on_error).decode("rgb")
    samples = samples.to_tuple("npy cls", handler=on_error)
    samples = samples.map_tuple(to_tensor, lambda x: x)

    # NOTE(review): the factor 500 is hard-coded; presumably the number of
    # samples per shard -- confirm.
    nominal_batches = (500 * len(urls)) // batch_size

    web_loader = wds.WebLoader(
        samples.with_length(nominal_batches),
        batch_size=batch_size,
        num_workers=workers,
    )
    return web_loader.with_length(nominal_batches)
Beispiel #10
0
 def __init__(self,
              urls,
              length,
              shuffle_buffer,
              nodesplitter=None,
              memory_cache=None):
     """Wrap a ``wds.WebDataset`` with optional shard and sample shuffling.

     Args:
         urls: Shard URLs for the dataset.
         length: Passed to ``wds.WebDataset`` as ``length``.
         shuffle_buffer: Shuffling is enabled only when this exceeds 1; it
             is then also passed to ``.shuffle()``.
         nodesplitter: Passed to ``wds.WebDataset``.
         memory_cache: Stored as ``self.memory_cache``.
     """
     super().__init__()
     self.memory_cache = memory_cache

     # One switch controls both shard-level and sample-level shuffling.
     shuffling = shuffle_buffer > 1
     dataset = wds.WebDataset(
         urls,
         shardshuffle=shuffling,
         length=length,
         nodesplitter=nodesplitter)
     if shuffling:
         dataset = dataset.shuffle(shuffle_buffer)
     self.dataset = dataset
    # Select matplotlib's 'Agg' backend.
    matplotlib.use('Agg')

    # Paths of the tar shards found by the directory walk below.
    all_tars = []

    model = GAN()
    if torch.cuda.is_available():
        decods = my_decoders(128)
        # Collect every .tar file below the current directory, skipping any
        # whose directory path or file name contains "out".
        for root, dirs, files in os.walk("."):
            for file in files:
                if file.endswith(
                        ".tar") and "out" not in root and "out" not in file:
                    all_tars.append(os.path.join(root, file))

        # The dataset is given an infinite nominal length and is batched
        # here (16 samples per batch), so the DataLoader below uses
        # batch_size=None.
        dataset = wds.WebDataset(all_tars, length=float("inf")) \
            .decode(decods.simple_decoder).to_tuple("gt.jpg", "__key__",
                                                    handler=dummy_func).batched(16)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=None,
                                                 num_workers=20,
                                                 collate_fn=coll)

        # Single-GPU trainer with 16-bit precision and the 'ddp' backend.
        trainer = pl.Trainer(
            gpus=1,
            log_every_n_steps=10,
            max_epochs=10,
            profiler=False,
            precision=16,
            distributed_backend='ddp')  #, logger=neptune_logger)
    else:
        # Without CUDA, the visible part of this branch only creates the
        # decoders.
        decods = my_decoders(128)
Beispiel #12
0
    def __init__(
        self,
        text_tar_filepaths: str,
        num_batches: int,
        shuffle_n: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 1,
    ):
        """Set up a WebDataset pipeline over tarred, pickled text batches.

        Args:
            text_tar_filepaths: Shard path(s). A string is brace-expanded
                after ``(``, ``[``, ``<``, ``_OP_`` have been rewritten to
                ``{`` and ``)``, ``]``, ``>``, ``_CL_`` to ``}``.
            num_batches: Total batch count; ``self.length`` is this value
                integer-divided by ``world_size``.
            shuffle_n: Argument for ``.shuffle()``; values <= 0 disable
                shuffling.
            shard_strategy: ``'scatter'`` gives each rank one contiguous,
                equally sized slice of the shards; ``'replicate'`` leaves
                the shard list untouched.
            global_rank: Rank used to select the slice when scattering.
            world_size: Number of distributed workers.

        Raises:
            ValueError: If ``shard_strategy`` is not a supported value.
        """
        super(TarredTextNormalizationDecoderDataset, self).__init__()

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(text_tar_filepaths, str):
            # Normalise the alternative brace spellings (openers first, then
            # closers) and expand the resulting pattern into a list.
            for opener in ('(', '[', '<', '_OP_'):
                text_tar_filepaths = text_tar_filepaths.replace(opener, "{")
            for closer in (')', ']', '>', '_CL_'):
                text_tar_filepaths = text_tar_filepaths.replace(closer, "}")
            text_tar_filepaths = list(
                braceexpand.braceexpand(text_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info(
                "All tarred dataset shards will be scattered evenly across all nodes."
            )
            shard_total = len(text_tar_filepaths)
            if shard_total % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({shard_total}) is not divisible "
                    f"by number of distributed workers ({world_size}).")
            # Each rank takes one contiguous slice; leftover shards beyond
            # the last full slice are not assigned to any rank.
            shards_per_rank = shard_total // world_size
            begin_idx = shards_per_rank * global_rank
            end_idx = begin_idx + shards_per_rank
            logging.info('Begin Index : %d', begin_idx)
            logging.info('End Index : %d', end_idx)
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        elif shard_strategy == 'replicate':
            logging.info(
                "All tarred dataset shards will be replicated across all nodes."
            )
        else:
            raise ValueError(
                f"Invalid shard strategy! Allowed values are: {valid_shard_strategies}"
            )

        # Put together WebDataset
        self._dataset = wd.WebDataset(urls=text_tar_filepaths,
                                      nodesplitter=None)
        self.length = num_batches // world_size
        if shuffle_n <= 0:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")
        else:
            self._dataset = self._dataset.shuffle(shuffle_n)

        pkl_and_key = self._dataset.rename(pkl='pkl', key='__key__').to_tuple('pkl', 'key')
        self._dataset = pkl_and_key.map(f=self._build_sample)
Beispiel #13
0
if ENABLE_WEBDATASET:
    # A nominal length is set on the Dataset to avoid warnings from DataLoader.
    DATASET_SIZE = int(1e9)

    myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS
    image_text_mapping = {myimg: imagetransform, mycap: tokenize}
    image_mapping = {myimg: imagepreproc}

    num_batches = DATASET_SIZE // BATCH_SIZE

    # NOTE: no .shuffle(is_shuffle) step is applied on the WebDataset path,
    # because its behaviour could not be predicted yet.
    ds = wds.WebDataset(DATASET, length=num_batches)
    ds = ds.map_dict(**image_text_mapping)
    ds = ds.map_dict(**image_mapping)
    ds = ds.to_tuple(mycap, myimg)
    # Partial batches are dropped, which is preferable for distributed training.
    ds = ds.batched(BATCH_SIZE, partial=False)
else:
    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

assert len(ds) > 0, 'dataset is empty'
Beispiel #14
0
    image_text_mapping = {
        myimg: imagetransform,
        mycap: tokenize
    }
    image_mapping = {
        myimg: imagepreproc
    }

    def filter_dataset(item):
        """Return True only for samples holding both the caption and image keys.

        Needed for sources such as C@H, which (rarely) have no caption
        available.
        """
        return mycap in item and myimg in item
	
    w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
    filtered_dataset = w_dataset.select(filter_dataset)
    ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True)
else:
    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )
    assert len(ds) > 0, 'dataset is empty'

if distr_backend.is_root_worker():
Beispiel #15
0
def dataio_prep_shards(hparams):
    """Build the sharded training and validation data pipelines.

    Args:
        hparams: Mapping providing ``train_meta``, ``val_meta``,
            ``sample_rate``, ``sentence_len``, ``save_folder``,
            ``train_shards``, ``val_shards`` and ``shard_cache_dir``.

    Returns:
        ``(train_data, valid_data, n_train, n_valid)`` where the two counts
        are the ``num_data_samples`` entries of the respective meta files.
    """

    def read_meta(hparams_key):
        # The meta info is a JSON file opened through webdataset's gopen.
        with wds.gopen.gopen(hparams[hparams_key], "rb") as stream:
            return json.load(stream)

    train_meta = read_meta("train_meta")
    val_meta = read_meta("val_meta")

    # Chunk length in samples used for random training crops.
    snt_len_sample = int(hparams["sample_rate"] * hparams["sentence_len"])

    # The label encoder is loaded from, or created at, a file in save_folder.
    label_encoder = sb.dataio.encoder.CategoricalEncoder()
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_iterables=[train_meta["language_ids"]],
        output_key="lang_id",
    )

    def audio_pipeline(sample_dict: Dict, random_chunk=True):
        """Map a raw sample to its signal, encoded language id and key."""
        key = sample_dict["__key__"]
        language_id = sample_dict["language_id"].decode("ascii")
        waveform = sample_dict["audio.pth"].squeeze()

        if random_chunk:
            # Pick a random window of snt_len_sample; start at 0 when the
            # signal leaves no room for a random offset.
            latest_start = len(waveform) - snt_len_sample - 1
            start = random.randint(0, latest_start) if latest_start > 0 else 0
            stop = start + snt_len_sample
        else:
            # Use the whole signal.
            start, stop = 0, len(waveform)

        return {
            "sig": waveform[start:stop],
            "lang_id_encoded": label_encoder.encode_label(language_id),
            "id": key,
        }

    def open_shards(shards_key):
        return wds.WebDataset(
            hparams[shards_key],
            cache_dir=hparams["shard_cache_dir"],
        )

    # Training data repeats indefinitely, is shuffled and randomly cropped.
    train_data = (
        open_shards("train_shards")
        .repeat()
        .shuffle(1000)
        .decode("pil")
        .map(partial(audio_pipeline, random_chunk=True))
    )
    logger.info(
        f"Training data consist of {train_meta['num_data_samples']} samples")

    # Validation data is read once, in order, using the full signals.
    valid_data = (
        open_shards("val_shards")
        .decode("pil")
        .map(partial(audio_pipeline, random_chunk=False))
    )
    logger.info(
        f"Validation data consist of {val_meta['num_data_samples']} samples")

    return (
        train_data,
        valid_data,
        train_meta["num_data_samples"],
        val_meta["num_data_samples"],
    )
Beispiel #16
0
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        """Set up the manifest collection, featurizer and tarred-audio pipeline.

        Args:
            audio_tar_filepaths: Shard path(s). In a string, ``(``, ``[``,
                ``<``, ``_OP_`` are rewritten to ``{`` and ``)``, ``]``,
                ``>``, ``_CL_`` to ``}``.
            manifest_filepath: Comma-separated manifest paths.
            parser: Forwarded to ``collections.ASRAudioText``.
            sample_rate: Forwarded to ``WaveformFeaturizer``.
            int_values: Forwarded to ``WaveformFeaturizer``.
            augmentor: Forwarded to ``WaveformFeaturizer``.
            shuffle_n: Argument for ``.shuffle()``; values <= 0 disable
                shuffling.
            min_duration: Forwarded to the manifest collection.
            max_duration: Forwarded to the manifest collection.
            max_utts: Forwarded to the manifest collection as ``max_number``.
            trim: Stored as ``self.trim``.
            bos_id: Stored as ``self.bos_id``.
            eos_id: Stored as ``self.eos_id``.
            pad_id: Stored as ``self.pad_id``.
            shard_strategy: ``'scatter'`` or ``'replicate'``; only acted on
                when ``world_size > 1``.
            global_rank: Rank used to select the slice when scattering.
            world_size: Number of distributed workers.

        Raises:
            ValueError: If ``shard_strategy`` is not a supported value.
        """
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
            # Required so manifest lines can be looked up by file ID.
            index_by_file_id=True,
        )

        self.featurizer = WaveformFeaturizer(
            sample_rate=sample_rate, int_values=int_values, augmentor=augmentor
        )
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Normalise the alternative brace spellings: all opening tokens
            # first, then all closing tokens.
            for opener in ('(', '[', '<', '_OP_'):
                audio_tar_filepaths = audio_tar_filepaths.replace(opener, "{")
            for closer in (')', ']', '>', '_CL_'):
                audio_tar_filepaths = audio_tar_filepaths.replace(closer, "}")

        # Shards are only expanded and partitioned for distributed runs;
        # otherwise the path(s) go to WebDataset as given.
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                shard_total = len(audio_tar_filepaths)
                if shard_total % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({shard_total}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                # One contiguous, equally sized slice per rank.
                shards_per_rank = shard_total // world_size
                begin_idx = shards_per_rank * global_rank
                end_idx = begin_idx + shards_per_rank
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Put together WebDataset
        self._dataset = wd.WebDataset(audio_tar_filepaths)
        if shuffle_n <= 0:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")
        else:
            self._dataset = self._dataset.shuffle(shuffle_n)

        audio_and_key = self._dataset.rename(audio='wav', key='__key__').to_tuple('audio', 'key')
        self._dataset = audio_and_key.pipe(self._filter).map(f=self._build_sample)
def decode_to_torch(sample):
    """Wrap a decoded sample in a ``torch_geometric`` ``Data`` object.

    The ``points.pth`` entry becomes ``pos`` and the ``mask.pth`` entry
    becomes ``y``. The ``class.cls`` entry is read but not part of the result.
    """
    from torch_geometric.data import Data
    import torch

    # All three keys are looked up, in this order, so a sample missing any
    # of them raises KeyError; the class label itself is unused.
    pos, y, _is_gamma = (
        sample["points.pth"],
        sample["mask.pth"],
        sample["class.cls"],
    )
    return Data(pos=pos, y=y)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()

    # Training shards: shuffled, decoded, then mapped through decode_to_torch.
    dataset = wds.WebDataset("/run/media/jacob/data/FACT_Dataset/fact-train-10-{0000..0040}.tar")
    dataset = dataset.shuffle(2000).decode()

    # Test shards: decoded in order, without shuffling.
    test_dataset = wds.WebDataset("/run/media/jacob/data/FACT_Dataset/fact-test-5-{0000..0017}.tar")
    test_dataset = test_dataset.decode()

    dataset = wds.Processor(dataset, wds.map, decode_to_torch)
    test_dataset = wds.Processor(test_dataset, wds.map, decode_to_torch)

    train_loader = DataLoader(
        dataset, num_workers=12, batch_size=1, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, num_workers=1, batch_size=1, pin_memory=True
    )
Beispiel #18
0
    def __init__(
        self,
        text_tar_filepaths: str,
        metadata_path: str,
        tokenizer,
        max_seq_length: int = 512,
        batch_step: int = None,
        shuffle_n: int = 1,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        """Set up a WebDataset pipeline over tarred text shards.

        Args:
            text_tar_filepaths: Shard path(s). A string is brace-expanded
                after ``(``, ``[``, ``<``, ``_OP_`` have been rewritten to
                ``{`` and ``)``, ``]``, ``>``, ``_CL_`` to ``}``. Any other
                value is used as the shard list unchanged.
            metadata_path: Path of a JSON file; its parsed content is stored
                as ``self.metadata``.
            tokenizer: Stored as ``self.tokenizer``.
            max_seq_length: Stored as ``self.max_seq_length``.
            batch_step: Stored as ``self.batch_step``; falls back to
                ``max_seq_length`` when falsy.
            shuffle_n: Argument for ``.shuffle()``; values <= 0 disable
                shuffling.
            shard_strategy: ``'scatter'`` gives each rank one contiguous,
                equally sized slice of the shards; ``'replicate'`` leaves
                the shard list untouched.
            global_rank: Rank used to select the slice when scattering.
            world_size: Number of distributed workers. Values below 1
                (including the default of 0) are treated as a single worker.

        Raises:
            ValueError: If ``shard_strategy`` is not a supported value.
        """
        super(TarredL2RLanguageModelingDataset, self).__init__()

        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.batch_step = batch_step or self.max_seq_length

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        self.metadata = metadata

        if isinstance(text_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            for opener in ('(', '[', '<', '_OP_'):
                text_tar_filepaths = text_tar_filepaths.replace(opener, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            for closer in (')', ']', '>', '_CL_'):
                text_tar_filepaths = text_tar_filepaths.replace(closer, "}")

            # Brace expand
            text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")

            # The default world_size of 0 used to raise ZeroDivisionError in
            # the arithmetic below; a non-positive value now means "one
            # worker", which keeps the whole shard list for rank 0.
            num_partitions = max(world_size, 1)
            shard_total = len(text_tar_filepaths)

            if shard_total % num_partitions != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({shard_total}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )

            # One contiguous, equally sized slice per rank; leftover shards
            # beyond the last full slice are not assigned to any rank.
            shards_per_rank = shard_total // num_partitions
            begin_idx = shards_per_rank * global_rank
            end_idx = begin_idx + shards_per_rank
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
            )

        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")

        else:
            raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

        self.tarpath = text_tar_filepaths

        # Put together WebDataset
        self._dataset = wd.WebDataset(text_tar_filepaths)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info("WebDataset will not shuffle files within the tar files.")

        # Expose each sample as an (npy, key) pair and build the final sample.
        self._dataset = self._dataset.rename(npy='npy', key='__key__').to_tuple('npy', 'key').map(f=self._build_sample)
Beispiel #19
0
def _iter_valid_subtiles(dataloader, tile_size):
    """Yield each sub-tile of the tiles served by *dataloader* that passes the filters.

    A tile is dropped when its last two dimensions differ or when all of its
    values are 0 or 1. Each remaining tile is cut with
    ``make_blocks_vectorized(data, tile_size)``; sub-tiles whose first band is
    constant (min == max) are dropped as well.
    """
    for data in tqdm(dataloader):
        data = data.squeeze(0).numpy()

        # ignore incomplete (non-square) tiles for stats
        if data.shape[-2] != data.shape[-1]:
            continue

        # skip tiles that contain nothing but the values 0 and 1
        if np.isin(data, [0, 1]).all():
            continue

        for subtile_rgbn in make_blocks_vectorized(data, tile_size):
            if subtile_rgbn[0].min() != subtile_rgbn[0].max():
                yield subtile_rgbn


def main():
    """Compute per-band mean/std over image tiles and write them to a JSON report.

    Command line:
        datapath  one or more directories that are searched for ``*.tar`` and
                  ``*.tif`` files.
        --frac    fraction of the found tif files to sample (default 1.0).

    The statistics are computed in two passes over the same dataloader (mean
    first, then the standard deviation around that mean). The report is
    written to ``processed.images.stats.json`` next to the first ``datapath``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("datapath", type=Path, nargs="+")

    parser.add_argument(
        "--frac",
        dest="frac",
        type=float,
        default=1.0,
        help="fraction of tiles to consider [range: 0-1, def: %(default)s]",
    )

    args = parser.parse_args()

    np.random.seed(42)
    print("Using fixed random seed!")

    # constants
    tile_size = 256
    size = tile_size**2

    # nargs="+" guarantees that args.datapath is a list of directories
    tar_files = sorted(
        itertools.chain.from_iterable(x.glob("*.tar") for x in args.datapath))
    tif_files = sorted(
        itertools.chain.from_iterable(x.glob("*.tif") for x in args.datapath))

    n_files = len(tif_files)

    SUBSET = int(round(args.frac * n_files, 0))
    selection = np.random.choice(range(n_files), size=SUBSET, replace=False)

    if len(tar_files) > len(tif_files):
        # webdataset
        dataset = (wds.WebDataset(
            [str(x) for x in tar_files]).map(sample_decoder).rename(
                image="rgbn.tif", mask="mask.tif",
                stats="txt").map_dict(image=transform).to_tuple("image"))
    else:
        # plain source tif dataset
        dataset = TifDataset(tif_files, transform=transform)

    # NOTE(review): the selection indices are drawn from the number of tif
    # files even when the tar branch is taken - confirm this is intended.
    dataset = Subset(dataset, selection)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False)

    mean, std = np.zeros(4), np.zeros(4)

    print("\nCalculating STATS")

    print("\nCalculating MEAN")
    cnt = 0
    for subtile_rgbn in _iter_valid_subtiles(dataloader, tile_size):
        mean += subtile_rgbn.sum((1, 2)) / size
        cnt += 1

    # cnt is the number of accumulated sub-tiles (not a zero-based index), so
    # it is the divisor; max() keeps the all-zero result when nothing qualified.
    mean /= max(cnt, 1)

    # add two trailing axes so the mean broadcasts against (band, y, x) sub-tiles
    mean_unsqueezed = mean[:, np.newaxis, np.newaxis]

    print("\nCalculating STD")
    cnt = 0
    for subtile_rgbn in _iter_valid_subtiles(dataloader, tile_size):
        std += ((subtile_rgbn - mean_unsqueezed)**2).sum((1, 2)) / size
        cnt += 1

    std /= max(cnt, 1)
    std = np.sqrt(std)

    df = pd.DataFrame({
        "band": ["red", "green", "blue", "nir"],
        "mean": mean.tolist(),
        "std": std.tolist(),
    })
    df = df.set_index("band")

    # report
    info = {
        "sources": [str(x) for x in args.datapath],
        "date": str(datetime.datetime.now()),
        "frac": args.frac,
        "subtiles": cnt,
        "results": json.loads(df.to_json(orient="index")),
    }

    # Serializing json
    with open(args.datapath[0].parent / "processed.images.stats.json",
              "w") as fout:
        fout.write(json.dumps(info, indent=4))
Beispiel #20
0
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        load_audio: bool = True,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        """Build a label dataset that streams audio out of tar shards.

        Args:
            audio_tar_filepaths: shard list, or a single string in which
                ``( [ < _OP_`` / ``) ] > _CL_`` stand in for ``{`` / ``}``.
            manifest_filepath: comma-separated manifest paths.
            labels: label vocabulary; when empty, the unique labels of the
                manifest collection are used.
            featurizer: stored on the instance as ``self.featurizer``.
            shuffle_n: shuffle buffer size; no shuffling when not positive.
            min_duration, max_duration: forwarded to the manifest collection.
            trim, load_audio: stored on the instance.
            shard_strategy: ``'scatter'`` or ``'replicate'``.
            global_rank, world_size: used to pick this process' shards when
                ``world_size > 1``.

        Raises:
            ValueError: if ``shard_strategy`` is not an allowed value.
        """
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath.split(','),
            min_duration=min_duration,
            max_duration=max_duration,
            # required so that manifest lines can be looked up by file ID
            index_by_file_id=True,
        )
        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim
        self.load_audio = load_audio

        self.labels = labels or self.collection.uniq_labels
        self.num_classes = len(self.labels)

        # forward and backward lookup between label strings and integer ids
        self.label2id = {name: pos for pos, name in enumerate(self.labels)}
        self.id2label = {pos: name for pos, name in enumerate(self.labels)}

        # show the first few mappings for debugging
        for idx, _ in enumerate(self.labels[:5]):
            logging.debug(" label id {} and its mapped label {}".format(
                idx, self.id2label[idx]))

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Map every alternative opening token onto '{' ...
            for opener in ('(', '[', '<', '_OP_'):
                audio_tar_filepaths = audio_tar_filepaths.replace(opener, "{")
            # ... and every alternative closing token onto '}'
            for closer in (')', ']', '>', '_CL_'):
                audio_tar_filepaths = audio_tar_filepaths.replace(closer, "}")

        # Shards are only partitioned when running distributed
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                expanded = braceexpand.braceexpand(audio_tar_filepaths)
                audio_tar_filepaths = list(expanded)

            if shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )
            elif shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                # every rank takes one contiguous, equally sized slice
                shards_per_rank = len(audio_tar_filepaths) // world_size
                begin_idx = shards_per_rank * global_rank
                end_idx = begin_idx + shards_per_rank
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)
            else:
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Assemble the WebDataset pipeline
        self._dataset = wd.WebDataset(audio_tar_filepaths)

        if shuffle_n <= 0:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")
        else:
            self._dataset = self._dataset.shuffle(shuffle_n)

        audio_key_pairs = self._dataset.rename(
            audio='wav', key='__key__').to_tuple('audio', 'key')
        self._dataset = audio_key_pairs.pipe(self._filter).map(
            f=self._build_sample)
Beispiel #21
0
def _write_subtile_stats(path, stats):
    """Write ``(tile, frac, status)`` records from *stats* to *path* as CSV."""
    with open(path, "w") as fout:
        fout.write("tile,frac,status\n")
        for fname, frac, status in stats:
            fout.write(f"{fname},{frac},{status}\n")


def main():
    """Split image/mask/lu tiles into sub-tiles and build the training shards.

    Steps, in order:
      1. split all tiles present in the image, mask and lu directories into
         ``train-%06d.tar`` shards and record the per-sub-tile stats as CSV;
      2. re-pack the sub-tiles with ``status > 0`` into ``train-balanced-*``
         shards of exactly ``SHARDSIZE`` samples;
      3. draw random sub-tiles that are not flagged with status 1 and write
         them to ``train-randomsamples-*`` shards (stats go to ``*_rnd.csv``);
      4. interleave balanced and random shards into ``train-combo-*`` shards;
      5. delete the intermediate ``train-random*``, ``train-balanced*`` and
         ``train-0*`` files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("image_dir", type=Path)
    parser.add_argument("mask_dir", type=Path)
    parser.add_argument("lu_dir", type=Path)
    parser.add_argument("outdir", type=Path)

    num_cores = psutil.cpu_count(logical=False)
    parser.add_argument(
        "--workers",
        dest="workers",
        type=int,
        default=num_cores,
        help="number of workers for parallel execution [def: %(default)s]",
    )

    parser.add_argument(
        "--source_dim",
        dest="source_dim",
        type=int,
        default=2048,
        help="size of input tiles [def: %(default)s]",
    )

    parser.add_argument(
        "--tile_size",
        dest="tile_size",
        type=int,
        default=256,
        help="size of final tiles that are then passed to the model [def: %(default)s]",
    )

    parser.add_argument(
        "--format",
        dest="format",
        type=str,
        default="TIFF",
        choices=["PNG", "TIFF"],
        help="target file format (PNG, TIFF) [def: %(default)s]",
    )

    parser.add_argument(
        "--tmp-dir",
        dest="tmp_dir",
        type=Path,
        default=None,
        help="use this location as tmp dir",
    )

    parser.add_argument(
        "--subdir",
        dest="sub_dir",
        default="train",
        help="use this location as sub_dir",
    )

    parser.add_argument(
        "--stats",
        dest="stats_file",
        type=Path,
        default=Path("stats.csv"),
        help="use this file to record stats",
    )

    args = parser.parse_args()

    args.outdir.mkdir(parents=True, exist_ok=True)
    Path(args.outdir / args.sub_dir).mkdir(parents=True, exist_ok=True)

    if args.tmp_dir:
        print(f"Using custom tmp dir: {args.tmp_dir}")
        Path(args.tmp_dir).mkdir(parents=True, exist_ok=True)

    if args.format == "TIFF":
        suffix = "tif"
    elif args.format == "PNG":
        suffix = "png"
    else:
        raise NotImplementedError

    SHUFFLE = True  # shuffle subtile order within shards (with fixed seed)

    images = sorted(args.image_dir.glob("*.tif"))
    masks = sorted(args.mask_dir.glob("*.tif"))
    lus = sorted(args.lu_dir.glob("*.tif"))

    image_names = {i.name for i in images}
    mask_names = {i.name for i in masks}
    lu_names = {i.name for i in lus}

    # Only tiles that exist in all three directories can be used for training.
    # The intersection is computed once instead of once per list element.
    common_names = image_names & mask_names & lu_names

    # the inputs are sorted already, so the filtered lists stay sorted
    train_images = [i for i in images if i.name in common_names]
    train_masks = [i for i in masks if i.name in common_names]
    train_lus = [i for i in lus if i.name in common_names]

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
    )

    subtile_stats = split_tiles(
        train_images,
        train_masks,
        train_lus,
        args.workers,
        str(args.outdir / args.sub_dir / "train-%06d.tar"),
        **cfg,
    )

    _write_subtile_stats(args.outdir / args.stats_file, subtile_stats)

    # rebalance shards so we get similar distributions in all shards
    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmpdir:
        print(f"Created a temporary directory: {tmpdir}")

        print("Extract source tars")
        # untar input
        for tf_name in sorted((args.outdir / args.sub_dir).glob("train-00*.tar")):
            with tarfile.open(tf_name) as tf:
                tf.extractall(tmpdir)

        print("Write balanced shards from deadtree samples")
        df = pd.read_csv(args.outdir / args.stats_file)

        df = df[df.status > 0]

        n_valid = len(df)
        splits = split_df(df, SHARDSIZE)

        # preserve last shard if more than 50% of values are present
        # (the `splits and` guard lets an empty result reach the assert below
        # instead of failing with an IndexError here)
        if splits and SHARDSIZE // 2 < len(splits[-1]) < SHARDSIZE:
            # fill last shard with duplicates (not ideal...)
            n_missing = SHARDSIZE - len(splits[-1])
            splits[-1].extend(np.random.choice(splits[-1], size=n_missing).tolist())

        # drop incomplete shards
        splits = [x for x in splits if len(x) == SHARDSIZE]
        assert len(splits) > 0, "Something went wrong"

        for s_cnt, s in enumerate(splits):

            with tarfile.open(
                args.outdir / args.sub_dir / f"train-balanced-{s_cnt:06}.tar", "w"
            ) as dst:

                if SHUFFLE:
                    random.shuffle(s)
                for i in s:
                    dst.add(f"{tmpdir}/{i}.mask.{suffix}", f"{i}.mask.{suffix}")
                    dst.add(f"{tmpdir}/{i}.lu.{suffix}", f"{i}.lu.{suffix}")
                    dst.add(f"{tmpdir}/{i}.rgbn.{suffix}", f"{i}.rgbn.{suffix}")
                    dst.add(f"{tmpdir}/{i}.txt", f"{i}.txt")

    # create sets for random tile dataset
    # use all subtiles not covered in train

    n_subtiles = (args.source_dim // args.tile_size) ** 2

    all_subtiles = {
        f"{Path(image_name).stem}_{c:03}"
        for image_name in image_names
        for c in range(n_subtiles)
    }

    n_samples = n_valid * OVERSAMPLE_FACTOR
    # sub-tiles recorded with status 1 are excluded from the random pool
    flagged_subtiles = {x[0] for x in subtile_stats if int(x[2]) == 1}
    random_subtiles = random.sample(
        tuple(all_subtiles - flagged_subtiles),
        n_samples,
    )

    # the necessary tile to process (strip the 4-character "_NNN" suffix)
    random_tiles = sorted({x[:-4] for x in random_subtiles})

    all_images = sorted(args.image_dir.glob("*.tif"))
    random_images = [x for x in all_images if x.stem in random_tiles]

    print("STATS")
    print(len(all_subtiles))
    print(len(subtile_stats))
    print(len(random_subtiles))
    print(len(random_images))

    cfg = dict(
        source_dim=args.source_dim,
        tile_size=args.tile_size,
        format=args.format,
        valid_subtiles=random_subtiles,  # subset data with random selection of subtiles
    )

    random_images_names = {i.name for i in random_images}
    random_lus = [i for i in lus if i.name in random_images_names]

    subtile_stats_rnd = split_tiles(
        random_images,
        [None] * len(random_images),
        random_lus,
        args.workers,
        str(args.outdir / args.sub_dir / "train-randomsamples-%06d.tar"),
        **cfg,
    )

    stats_file_rnd = Path(args.stats_file.stem + "_rnd.csv")
    _write_subtile_stats(args.outdir / stats_file_rnd, subtile_stats_rnd)

    # also create combo dataset
    # source A: train-balanced, source B: randomsample
    # NOTE: combo shards hold 2 * SHARDSIZE samples, alternating between
    # regular and random samples
    train_balanced_shards = [
        str(x) for x in sorted((args.outdir / args.sub_dir).glob("train-balanced*"))
    ]
    train_balanced_shards_rnd = [
        str(x) for x in sorted((args.outdir / args.sub_dir).glob("train-random*"))
    ]
    train_balanced_shards_rnd = train_balanced_shards_rnd[: len(train_balanced_shards)]

    shardpattern = str(args.outdir / args.sub_dir / "train-combo-%06d.tar")

    with wds.ShardWriter(shardpattern, maxcount=SHARDSIZE * 2) as sink:
        for shardA, shardB in zip(train_balanced_shards, train_balanced_shards_rnd):

            for sA, sB in zip(wds.WebDataset(shardA), wds.WebDataset(shardB)):
                sink.write(sA)
                sink.write(sB)

    # remove everything but train & combo
    for filename in (args.outdir / args.sub_dir).glob("train-random*"):
        filename.unlink()
    for filename in (args.outdir / args.sub_dir).glob("train-balanced*"):
        filename.unlink()
    for filename in (args.outdir / args.sub_dir).glob("train-0*"):
        filename.unlink()
Beispiel #22
0
    def setup(
        self,
        split_fractions: List[float] = DeadtreeDatasetConfig.fractions,
        in_channels: Optional[int] = 4,  # change to 3 for rgb training instead of rgbn
        classes: Optional[int] = 3,  # change to 2 for single class (+bg) setup
    ) -> None:
        """Create the train/val/test WebDataset pipelines and the extra datasets.

        For the ``"single_directory"`` layout the shards are split by
        ``split_fractions``; otherwise ``self.data_shards`` must already hold
        the (train, valid, test) shard groups. The number of samples per shard
        is measured by iterating over the first training shard.
        """
        if self.layout != "single_directory":
            raw_train, raw_valid, raw_test = self.data_shards
            train_shards = list(map(str, raw_train))
            valid_shards = list(map(str, raw_valid))
            test_shards = list(map(str, raw_test))
        else:
            train_shards, valid_shards, test_shards = split_shards(
                self.data_shards, split_fractions
            )

        # count the samples of one shard to derive the dataset length
        shard_size = 0
        for _ in DataLoader(wds.WebDataset(train_shards[0])):
            shard_size += 1
        logger.info(
            f"Shard size: {shard_size} (estimate base on file: {train_shards[0]})"
        )

        def build_dataset(
            shards: List[str],
            bs: int,
            transform_func: Callable,
            shuffle: Optional[int] = 128,
            shard_size: Optional[int] = 128,
        ) -> wds.WebDataset:
            """Assemble the decode/rename/transform pipeline for *shards*."""
            sample_transform = partial(
                transform,
                transform_func=transform_func,
                in_channels=in_channels,
                classes=classes,
                distmap=True,
            )
            n_batches = len(shards) * shard_size // bs

            pipeline = wds.WebDataset(shards, length=n_batches)
            pipeline = pipeline.shuffle(shuffle)
            pipeline = pipeline.map(sample_decoder)
            pipeline = pipeline.rename(
                image="rgbn.tif", mask="mask.tif", lu="lu.tif", stats="txt"
            )
            pipeline = pipeline.map(sample_transform)
            return pipeline.to_tuple("image", "mask", "distmap", "lu", "stats")

        self.train_data = build_dataset(
            train_shards,
            self.train_dataloader_conf["batch_size"],
            transform_func=train_transform,
            shuffle=shard_size,
            shard_size=shard_size,
        )

        self.val_data = build_dataset(
            valid_shards,
            self.val_dataloader_conf["batch_size"],
            transform_func=val_transform,
            shuffle=0,
            shard_size=shard_size,
        )

        # the test dataset only exists when test shards are available
        if test_shards:
            self.test_data = build_dataset(
                test_shards,
                self.test_dataloader_conf["batch_size"],
                transform_func=val_transform,
                shuffle=0,
                shard_size=shard_size,
            )

        self.extra_train_data = []
        self.extra_valid_data = []

        if len(self.data_shards_extra) > 0:
            # extra shards are split between train and val in the same
            # proportion as the main dataset
            n_train, n_valid = len(train_shards), len(valid_shards)
            train_frac = n_train / (n_train + n_valid)
            valid_frac = 1 - train_frac

            for bs, shards in zip(self.batch_size_extra, self.data_shards_extra):
                extra_train_shards, extra_valid_shards, _ = split_shards(
                    shards, [train_frac, valid_frac]
                )

                extra_train = build_dataset(
                    extra_train_shards,
                    bs,
                    transform_func=train_transform,
                    shuffle=shard_size,
                    shard_size=shard_size,
                )
                self.extra_train_data.append(extra_train)

                extra_valid = build_dataset(
                    extra_valid_shards,
                    bs,
                    transform_func=val_transform,
                    shuffle=0,
                    shard_size=shard_size,
                )
                self.extra_valid_data.append(extra_valid)
import io
import os
import pickle
import shutil
import torch
import webdataset as wds
from torch.utils.data import TensorDataset
import tarfile

if __name__ == '__main__':
    # Read (ground truth, train image, key) triples from the shard, save both
    # tensors per sample and pack each sample directory into its own archive.
    dataset = wds.WebDataset("train_0000000.tar", length=float("inf")) \
        .decode(my_decoder_GT).decode(my_decoder_BW).to_tuple("gt.jpg", "train.jpg", "__key__").batched(1)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
    os.makedirs("preprocessed_data_tars", exist_ok=True)

    for i, (gt, bw, name) in enumerate(dataloader):
        print(f"{i + 1} / 1024")

        sample_key = name[0]
        # everything in front of the first "/" names the sample's directory
        slash_pos = sample_key.find("/")
        sample_dir = "preprocessed_data/" + sample_key[:slash_pos]

        os.makedirs(sample_dir, exist_ok=True)
        torch.save(gt, "preprocessed_data/" + sample_key + ".gt.pt")
        torch.save(bw, "preprocessed_data/" + sample_key + ".train.pt")

        # the context manager closes the archive even when adding files fails
        with tarfile.open(
                f"preprocessed_data_tars/{sample_key[:slash_pos]}.tar.bz2",
                mode="w:bz2") as tar:
            tar.add("preprocessed_data/" + sample_key[:slash_pos + 1])

        # the uncompressed tensors are no longer needed once archived
        shutil.rmtree(sample_dir, ignore_errors=True)
    def __init__(
        self,
        text_tar_filepaths: str,
        metadata_path: str,
        encoder_tokenizer: str,
        decoder_tokenizer: str,
        shuffle_n: int = 1,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        reverse_lang_direction: bool = False,
    ):
        """Build a translation dataset that streams pickled samples from tar shards.

        Args:
            text_tar_filepaths: shard list, or a single string that is
                brace-expanded; ``( [ < _OP_`` / ``) ] > _CL_`` may be used in
                place of ``{`` / ``}``.
            metadata_path: JSON file that is loaded into ``self.metadata``.
            encoder_tokenizer: source tokenizer; its ``pad_id`` attribute is
                read, so an object is expected despite the ``str`` annotation.
            decoder_tokenizer: target tokenizer; same note as above.
            shuffle_n: shuffle buffer size; no shuffling when not positive.
            shard_strategy: ``'scatter'`` or ``'replicate'``.
            global_rank: rank of this process, selects its slice of shards.
            world_size: number of processes. Shards are only partitioned when
                it is greater than 1, so the default of 0 keeps all shards
                instead of raising ``ZeroDivisionError``.
            reverse_lang_direction: stored on the instance.

        Raises:
            ValueError: if ``shard_strategy`` is not an allowed value.
        """
        super(TarredTranslationDataset, self).__init__()

        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.reverse_lang_direction = reverse_lang_direction
        self.src_pad_id = encoder_tokenizer.pad_id
        self.tgt_pad_id = decoder_tokenizer.pad_id

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        self.metadata = metadata

        if isinstance(text_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            for bkey in ('(', '[', '<', '_OP_'):
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            for bkey in (')', ']', '>', '_CL_'):
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

            # Brace expand
            text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

        # shard_strategy was validated above, so only these two values remain
        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            # Partition only in a real distributed run; this also avoids the
            # modulo/division by zero for the default world_size of 0.
            if world_size > 1:
                if len(text_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size})."
                    )
                begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
                end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
                logging.info('Begin Index : %d' % (begin_idx))
                logging.info('End Index : %d' % (end_idx))
                text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
                )
        else:
            logging.info("All tarred dataset shards will be replicated across all nodes.")

        self.tarpath = text_tar_filepaths

        # Put together WebDataset
        self._dataset = wd.WebDataset(text_tar_filepaths)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info("WebDataset will not shuffle files within the tar files.")

        self._dataset = self._dataset.rename(pkl='pkl', key='__key__').to_tuple('pkl', 'key').map(f=self._build_sample)
Beispiel #25
0
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional[
            'nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        return_sample_id: bool = False,
    ):
        """Build a dataset that streams audio from tar shards.

        The manifest arguments go to ``ASRManifestProcessor``, the audio
        arguments to ``WaveformFeaturizer``, and the shard arguments to
        ``expand_audio_filepaths``, whose result feeds the WebDataset. The
        shuffle stage is only added when ``shuffle_n`` is positive.
        """
        manifest_options = dict(
            manifest_filepath=manifest_filepath,
            parser=parser,
            max_duration=max_duration,
            min_duration=min_duration,
            max_utts=max_utts,
            bos_id=bos_id,
            eos_id=eos_id,
            pad_id=pad_id,
            # required so that manifest lines can be looked up by file ID
            index_by_file_id=True,
        )
        self.manifest_processor = ASRManifestProcessor(**manifest_options)

        self.featurizer = WaveformFeaturizer(
            sample_rate=sample_rate,
            int_values=int_values,
            augmentor=augmentor,
        )
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id
        self.return_sample_id = return_sample_id

        audio_tar_filepaths = expand_audio_filepaths(
            audio_tar_filepaths=audio_tar_filepaths,
            shard_strategy=shard_strategy,
            world_size=world_size,
            global_rank=global_rank,
        )

        # Assemble the WebDataset pipeline
        self._dataset = wd.WebDataset(urls=audio_tar_filepaths,
                                      nodesplitter=None)

        if shuffle_n <= 0:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")
        else:
            self._dataset = self._dataset.shuffle(shuffle_n)

        # take the first available audio extension plus the sample key
        audio_key_pairs = self._dataset.rename(
            audio='wav;ogg;flac', key='__key__').to_tuple('audio', 'key')
        filtered = audio_key_pairs.pipe(self._filter)
        with_offsets = filtered.pipe(self._loop_offsets)
        self._dataset = with_offsets.map(f=self._build_sample)