def make_train_loader(epoch_size=1100000, batch_size=64, shuffle=20000, augment=True):
    """Build a DataLoader over the training shards listed in ``trainurls``.

    Samples are shuffled, decoded to PIL images, transformed and batched
    inside the WebDataset pipeline, so the DataLoader itself does no
    batching (``batch_size=None``).

    Args:
        epoch_size: Nominal dataset length passed to ``wds.Dataset`` as
            ``length``.
        batch_size: Number of samples per batch produced by ``.batched``.
        shuffle: Size of the in-memory shuffle buffer.
        augment: When True (the default, which was previously hard-coded),
            use RandomResizedCrop(224) + RandomHorizontalFlip; otherwise use
            the deterministic Resize(256) + CenterCrop(224) transform.

    Returns:
        A ``data.DataLoader`` with 4 worker processes.
    """
    if augment:
        image_transform = torchvision.transforms.Compose([
            torchvision.transforms.RandomResizedCrop(224),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            normalize,
        ])
    else:
        image_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            normalize,
        ])
    dataset = (
        wds.Dataset(trainurls, handler=wds.warn_and_stop, length=epoch_size)
        .shuffle(shuffle)
        .decode("pil", handler=wds.warn_and_continue)
        .to_tuple("ppm;jpg;jpeg;png", "cls", handler=wds.warn_and_continue)
        .map_tuple(image_transform, identity, handler=wds.warn_and_continue)
        .batched(batch_size)
    )
    # Batching already happened in the pipeline; the loader must not re-batch.
    loader = data.DataLoader(dataset, batch_size=None, num_workers=4)
    return loader
def test_torch_map_dict_decoder():
    """Decode the rgb/mask PNG fields with custom decoders via ``map_dict``."""

    def _decode_png(raw, mode):
        # Force a full load before the in-memory buffer is closed.
        with io.BytesIO(raw) as buf:
            pil_img = Image.open(buf)
            pil_img.load()
            pil_img = pil_img.convert(mode)
        return np.asarray(pil_img)

    def image_decoder(data):
        # HWC -> CHW copy, scaled to [0, 1].
        chw = np.array(_decode_png(data, "RGB").transpose(2, 0, 1))
        return torch.tensor(chw) / 255.0

    def mask_decoder(data):
        return torch.tensor(_decode_png(data, "L"))

    pipeline = wds.Dataset(test_data)
    pipeline = pipeline.rename(image="rgb.png", mask="msk.png")
    pipeline = pipeline.map_dict(image=image_decoder, mask=mask_decoder)
    pipeline = pipeline.to_tuple("image", "mask")
    image, mask = next(iter(pipeline))
    assert image.shape == (3, 512, 512)
    assert mask.shape == (512, 512)
def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128):
    """Index the manifest by file ID and open an iterator over (audio, key) tar samples.

    Args:
        manifest_path: Manifest passed to ``collections.ASRAudioText``.
        tar_filepaths: Brace pattern string or list of tar paths. In a string,
            '(', '[', '<', '_OP_' are rewritten to '{' and ')', ']', '>',
            '_CL_' to '}'.
        shuffle_n: WebDataset shuffle buffer size.
    """
    self._manifest = collections.ASRAudioText(
        manifest_path,
        parser=parsers.make_parser([]),
        index_by_file_id=True,
    )

    if isinstance(tar_filepaths, str):
        # Opening tokens first, then closing tokens, mirroring brace syntax.
        substitutions = [(tok, "{") for tok in ('(', '[', '<', '_OP_')]
        substitutions += [(tok, "}") for tok in (')', ']', '>', '_CL_')]
        for token, brace in substitutions:
            tar_filepaths = tar_filepaths.replace(token, brace)

    pipeline = wd.Dataset(tar_filepaths).shuffle(shuffle_n)
    pipeline = pipeline.rename(audio='wav', key='__key__')
    self.audio_dataset = pipeline.to_tuple('audio', 'key')
    self.audio_iter = iter(self.audio_dataset)
def setup(self, stage=None):
    """Build the train, validation and test WebDataset pipelines.

    Shards are read from ``<train_glob>/<split>/<split>-{0..N}.tar`` for
    split in train/val/test, where ``train_glob`` comes from ``self.args``
    (default ``/pvc/output/processing``) and N is
    ``self.get_num_files(<split dir>)``. Nothing is downloaded here.

    Args:
        stage: Stage - training or testing. Currently unused; all three
            datasets are always built.
    """
    data_path = self.args.get("train_glob", "/pvc/output/processing")

    def _make_split(split, transform, length, batch_size):
        # One split's pipeline; all splits share identical stages.
        base_url = data_path + "/" + split
        count = self.get_num_files(base_url)
        # NOTE(review): brace ranges are inclusive, so "{0..count}" names
        # count + 1 shards; confirm get_num_files returns the last shard
        # index rather than the number of shard files.
        url = "{}/{}-{}".format(base_url, split, "{0.." + str(count) + "}.tar")
        return (
            wds.Dataset(url, handler=wds.warn_and_continue, length=length)
            .shuffle(100)
            .decode("pil")
            # Same accepted extensions for every split (val/test previously
            # accepted only "ppm").
            .rename(image="ppm;jpg;jpeg;png", info="cls")
            .map_dict(image=transform)
            .to_tuple("image", "info")
            .batched(batch_size)
        )

    self.train_dataset = _make_split("train", self.train_transform, 40000 // 40, 40)
    self.valid_dataset = _make_split("val", self.valid_transform, 10000 // 20, 20)
    self.test_dataset = _make_split("test", self.valid_transform, 10000 // 20, 20)
def KineticsSounds(cfg, split):
    """Build the WebDataset pipeline for one KineticsSounds split.

    Args:
        cfg: Config providing DATASET_ROOT, STORAGE_SAS_KEY, TRAIN/VAL/TEST
            batch and dataset sizes, SOLVER.GRADIENT_ACCUMULATION_STEPS and
            DATA_LOADER.NUM_WORKERS.
        split: One of 'train', 'val' or 'test'.

    Returns:
        A ``wds.ResizedDataset`` whose length and nominal length are both the
        per-rank batch count when ``cfg.DATA_LOADER.NUM_WORKERS > 0``, and
        both the per-rank sample count otherwise.

    Raises:
        ValueError: If ``split`` is not 'train', 'val' or 'test'.
    """
    # Last shard index per split; the brace range in the URL is inclusive.
    last_shard_index = {'train': 19, 'val': 1, 'test': 2}
    if split not in last_shard_index:
        raise ValueError(
            f"split must be one of {sorted(last_shard_index)}, got {split!r}")
    max_idx = last_shard_index[split]

    dataset_root = cfg.DATASET_ROOT
    if dataset_root.endswith('/'):
        dataset_root = dataset_root[:-1]
    url = f"{dataset_root}/KineticsSounds/shards-{split}/shard-{{000000..{max_idx:06d}}}.tar"
    if cfg.STORAGE_SAS_KEY:
        url += cfg.STORAGE_SAS_KEY

    _decoder = Decoder(cfg, "KineticsSounds", split)

    if split == 'train':
        batch_size = int(cfg.TRAIN.BATCH_SIZE / cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS)
        batch_size = int(batch_size / du.get_world_size())
        length = int(cfg.TRAIN.DATASET_SIZE / du.get_world_size())
        nominal = int(length / batch_size)
    elif split == 'val':
        # NOTE(review): validation batch size is derived from TRAIN.BATCH_SIZE;
        # confirm this is intended.
        batch_size = int(cfg.TRAIN.BATCH_SIZE / du.get_world_size())
        length = int(cfg.VAL.DATASET_SIZE / du.get_world_size())
        nominal = int(length / batch_size)
    else:  # 'test': round up so no trailing samples are dropped from the count
        batch_size = int(cfg.TEST.BATCH_SIZE / du.get_world_size())
        length = math.ceil(cfg.TEST.DATASET_SIZE / du.get_world_size())
        nominal = math.ceil(length / batch_size)

    # NOTE: this rebinds the module attribute wds.filters.batched so batching
    # uses the kinetics collation function; any later code that looks up
    # wds.filters.batched sees the patched version.
    wds.filters.batched = wds.filters.Curried(
        partial(wds.filters.batched_, collation_fn=COLLATE_FN["kinetics"]))

    dataset = wds.Dataset(
        url,
        handler=wds.warn_and_continue,
        shard_selection=du.shard_selection,
        length=length,
    )
    if split == 'train':
        dataset = dataset.shuffle(100)
    dataset = dataset.map_dict(
        handler=wds.warn_and_continue,
        mp4=_decoder.mp4decode,
        json=_decoder.jsondecode,
    )

    if cfg.DATA_LOADER.NUM_WORKERS > 0:
        length = nominal
    else:
        nominal = length
    dataset = wds.ResizedDataset(
        dataset,
        length=length,
        nominal=nominal,
    )
    return dataset
def main(args):
    """Fine-tune a VQ-VAE and an MLP entropy coder on CelebA WebDataset shards.

    Loads VQ-VAE weights from a fixed checkpoint, wraps both models in
    ``_CustomDataParallel`` when more than one GPU is visible, and calls
    ``train`` once per epoch for ``args.epoch`` epochs at learning rate
    ``args.lr``.
    """
    device = "cuda"
    preproc = tf.Compose([tf.Resize(256), tf.CenterCrop(256), tf.ToTensor()])
    url = '../data_celeba_tar/train_{0..162}.tar'
    dataset = (
        wds.Dataset(url, length=162000 // 32)
        .shuffle(500)
        .decode("pil")
        .to_tuple("input.jpg", "sensitive.cls")
        .map_tuple(preproc, identity)
        .batched(32)
    )
    # Batches are formed in the pipeline, so the DataLoader must not re-batch.
    loader = DataLoader(dataset, batch_size=None, num_workers=16)

    model = VQVAE(cout=30)
    checkpoints = "/scratch/xgitiaux/checkpoint/vqvae/two_q_vqvae_017.pt"
    logger.info(f'Loading checkpoint {checkpoints}')
    checkpoint = torch.load(checkpoints, map_location='cpu')
    # Strip the "module." prefix only where present; slicing every key
    # unconditionally would corrupt keys saved from an unwrapped model.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint.items()
    )
    model.load_state_dict(new_state_dict)
    model = model.to(device)

    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        logger.info(f'Number of gpu is {n_gpus}')
        model = _CustomDataParallel(model)

    entropy_coder = MLP(32 * 32, depth=3, width=256).to(device)
    if n_gpus > 1:
        logger.info(f'Number of gpu is {n_gpus}')
        entropy_coder = _CustomDataParallel(entropy_coder)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    poptimizer = optim.Adam(entropy_coder.parameters(), lr=args.lr)
    scheduler = None
    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device, entropy_coder, poptimizer)
def make_val_loader(epoch_size=50000, batch_size=64):
    """Build a DataLoader over the validation shards listed in ``valurls``.

    Images are resized to 256, center-cropped to 224, converted to tensors
    and normalized; batching happens inside the pipeline.

    Args:
        epoch_size: Accepted for symmetry with the training loader; not used.
        batch_size: Number of samples per batch.

    Returns:
        A ``data.DataLoader`` (``batch_size=None``, 4 workers).
    """
    tfm = torchvision.transforms
    val_transform = tfm.Compose([
        tfm.Resize(256),
        tfm.CenterCrop(224),
        tfm.ToTensor(),
        normalize,
    ])
    pipeline = wds.Dataset(valurls, handler=wds.warn_and_stop, prepare_for_worker=False)
    pipeline = pipeline.decode("pil", handler=wds.warn_and_continue)
    pipeline = pipeline.to_tuple("ppm;jpg;jpeg;png", "cls", handler=wds.warn_and_continue)
    pipeline = pipeline.map_tuple(val_transform, identity, handler=wds.warn_and_continue)
    pipeline = pipeline.batched(batch_size)
    return data.DataLoader(pipeline, batch_size=None, num_workers=4)
def load_gqn_dataset(name, batch_size, seed=42, shuffle=False,
                     target_transform=transform_viewpoint,
                     max_samples_per_environment=-1):
    """Build a batched WebDataset pipeline over one split of a GQN dataset.

    Args:
        name: Identifier parsed by ``split_name`` into ``(dataset_name, split)``;
            the split must be 'train' or 'test'.
        batch_size: Number of samples per batch.
        seed: Seed for the single ``random.Random`` shared by shard shuffling,
            environment sampling, sample shuffling and batch transforms.
        shuffle: Whether to shuffle shards and samples.
        target_transform: Forwarded to ``transform_batch``.
        max_samples_per_environment: Forwarded to ``sample_environment`` as
            ``max_samples_per_env``.

    Returns:
        The WebDataset pipeline yielding transformed batches.

    Raises:
        AssertionError: If the dataset or split is not supported.
    """
    dataset_name, split = split_name(name)
    assert dataset_name in _DATASET_INFO, f'Dataset {dataset_name} is not supported'
    assert split in ['test', 'train'], f'Split {split} is not supported'
    dataset_info = _DATASET_INFO[dataset_name]
    size = dataset_info[f'{split}_size']
    url = os.path.join(
        DATASET_PATH,
        f'{dataset_name}-wd',
        f'{dataset_name}-{split}-{{000001..{size:06d}}}-of-{size:06d}.tar',
    )
    sample_size = dataset_info['max_num_views'] + 1
    environment_size = dataset_info['sequence_size']

    dataset = wds.Dataset(url)
    rng = random.Random(seed)
    dataset.rng = rng
    dataset.reseed_hook = dataset.reseed_rng
    if shuffle:
        dataset.shard_shuffle = wds.dataset.Shuffler(rng)
    dataset = dataset.pipe(
        partial(sample_environment,
                sample_size=sample_size,
                environment_size=environment_size,
                shuffle=shuffle,
                rng=rng,
                max_samples_per_env=max_samples_per_environment))
    if shuffle:
        dataset = dataset.pipe(wds.filters.shuffle(10000, rng=rng, initial=1000))
    dataset = dataset.to_tuple('camera.pth', 'image.jpg')
    dataset = dataset.batched(batch_size)
    # Use the caller-supplied transform (previously the module-level default
    # was always passed and this parameter was silently ignored).
    dataset = dataset.pipe(
        partial(transform_batch, rng=rng, target_transform=target_transform))
    return dataset
def distribute_remaining_data(data_path, subset, total_instance, chunks):
    """Re-save the trailing shard chunks that do not split evenly across instances.

    The last ``len(chunks) % total_instance`` chunks (or the last
    ``total_instance`` chunks when the split is even) are read as a single
    brace-ranged WebDataset, decoded per the ``format`` entry in
    ``<data_path>/metadata.json``, and passed to ``save_remaining_data``
    with the output pattern
    ``<data_path>/distributed/<total_instance>-instances/<subset>-%06d.tar``.

    Raises:
        ValueError: If ``total_instance`` is not positive or ``chunks`` is empty.
    """
    if total_instance < 1:
        raise ValueError(f"total_instance must be positive, got {total_instance}")
    if not chunks:
        raise ValueError("chunks must not be empty")

    remaining_chunks = len(chunks) % total_instance
    if remaining_chunks == 0:
        remaining_chunks = total_instance
    remaining_start, remaining_end = chunks[-remaining_chunks], chunks[-1]
    chunk_str = f"{remaining_start}..{remaining_end}"
    dataset = wds.Dataset(os.path.join(data_path, subset + "-{" + chunk_str + "}.tar"))

    # determine saving format
    with open(os.path.join(data_path, "metadata.json"), encoding="utf-8") as metadata_file:
        metadata = json.load(metadata_file)
    data_format = metadata["format"]
    dataset = decode_webdataset(dataset, data_format, identity, use_bbox_info=True)

    folder = os.path.join(data_path, "distributed", str(total_instance) + "-instances")
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(folder, exist_ok=True)
    distributed_save_path = os.path.join(folder, subset + "-%06d.tar")
    save_remaining_data(dataset, distributed_save_path, total_instance)
def train_dataloader(self):
    """Return the training DataLoader.

    When ``self.imagenet`` is None or empty, samples come from the WebDataset
    shards in ``self.trainurls``; otherwise from
    ``torchvision.datasets.ImageNet`` rooted at ``self.imagenet``. Either
    source is wrapped in ``wds.ResizedDataset`` sized by ``self.epoch``.
    """
    if self.imagenet not in [None, ""]:
        imagenet = torchvision.datasets.ImageNet(
            self.imagenet, split="train", transform=image_transform
        )
        dataset = wds.ResizedDataset(imagenet, self.epoch)
    else:
        shards = wds.Dataset(self.trainurls, handler=wds.warn_and_continue)
        shards = shards.shuffle(5000)
        shards = shards.decode("pil", handler=wds.warn_and_continue)
        shards = shards.to_tuple("ppm;jpg;jpeg;png", "cls")
        shards = shards.map_tuple(image_transform, identity)
        # Ceiling division: batches needed to cover one nominal epoch.
        batches_per_epoch = (self.epoch + self.batch_size - 1) // self.batch_size
        dataset = wds.ResizedDataset(shards, self.epoch, nominal=batches_per_epoch)

    return data.DataLoader(
        dataset, batch_size=self.batch_size, num_workers=self.num_workers
    )
# Print the keys and decoded video/audio tensors of the first three samples
# in a WebDataset tar containing .mp4 entries.
import torch
import torchvision
import webdataset as wds
from itertools import islice
import tempfile

url = "testoutput.tar"


def mp4decode(data):
    """Decode raw MP4 bytes via torchvision, round-tripping through a temp file.

    Returns the ``(vframes, aframes, info)`` triple from
    ``torchvision.io.read_video``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        video_path = tmpdir + "/sample.mp4"
        with open(video_path, "wb") as sink:
            sink.write(data)
        # read_video is given a path, hence the temporary file.
        video, audio, meta = torchvision.io.read_video(video_path)
        return video, audio, meta


dataset = wds.Dataset(url).decode().map_dict(mp4=mp4decode)

for sample in islice(dataset, 3):
    print("---")
    print(list(sample.keys()))
    vframes, aframes, info = sample["mp4"]
    print(vframes.shape, aframes.shape)
    print(info)
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    load_audio: bool = True,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Index a speech-label manifest and build a WebDataset over tarred audio.

    Args:
        audio_tar_filepaths: Brace-pattern string or explicit list of tar paths.
            In a string, '(', '[', '<', '_OP_' become '{' and ')', ']', '>',
            '_CL_' become '}'.
        manifest_filepath: Comma-separated manifest paths.
        labels: Label vocabulary; falls back to the manifest's unique labels
            when empty.
        featurizer: Stored on ``self.featurizer``.
        shuffle_n: WebDataset shuffle buffer size.
        min_duration, max_duration: Forwarded to ``ASRSpeechLabel``.
        trim, load_audio: Stored on the instance.
        shard_strategy: 'scatter' splits shards across ranks; 'replicate'
            gives every rank all shards.
        global_rank, world_size: Shards are partitioned only when
            ``world_size > 1``.

    Raises:
        ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
    """
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath.split(','),
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.file_occurence = count_occurence(self.collection.mapping)

    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    self.label2id = {name: position for position, name in enumerate(self.labels)}
    self.id2label = {position: name for position, name in enumerate(self.labels)}
    for idx in range(len(self.labels[:5])):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        # Map alternative bracket tokens onto the braces used for expansion.
        for token in ('(', '[', '<', '_OP_'):
            audio_tar_filepaths = audio_tar_filepaths.replace(token, "{")
        for token in (')', ']', '>', '_CL_'):
            audio_tar_filepaths = audio_tar_filepaths.replace(token, "}")

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            num_shards = len(audio_tar_filepaths)
            if num_shards % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({num_shards}) is not divisible "
                    f"by number of distributed workers ({world_size}).")
            per_rank = num_shards // world_size
            begin_idx = per_rank * global_rank
            end_idx = begin_idx + per_rank
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")
        else:
            raise ValueError(
                f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    # Put together WebDataset
    pipeline = wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n)
    pipeline = pipeline.rename(audio='wav', key='__key__').to_tuple('audio', 'key')
    self._dataset = pipeline.pipe(self._filter).map(f=self._build_sample)
def __init__(
    self,
    audio_tar_filepaths,
    manifest_filepath,
    labels,
    batch_size,
    sample_rate=16000,
    int_values=False,
    bos_id=None,
    eos_id=None,
    pad_id=None,
    min_duration=0.1,
    max_duration=None,
    normalize_transcripts=True,
    trim_silence=False,
    shuffle_n=0,
    num_workers=0,
    augmentor: Optional[Union[AudioAugmentor, Dict[str, Dict[str, Any]]]] = None,
):
    """Index an ASR manifest and build a WebDataset over tarred audio.

    When ``torch.distributed`` is initialized, the (brace-expanded) shard
    list is cut into equal contiguous slices and this rank keeps slice
    ``get_rank()``; any remainder shards are dropped with a warning.

    Args:
        audio_tar_filepaths: Brace-pattern string or list of tar paths.
        manifest_filepath: Comma-separated manifest paths.
        labels: Character labels passed to ``make_parser``.
        batch_size, num_workers: Stored for later sampler creation.
        sample_rate, int_values, augmentor: Featurizer settings; ``augmentor``
            is first run through ``_process_augmentations`` when given.
        bos_id, eos_id: Stored on the instance.
        pad_id: Token padding value for ``seq_collate_fn`` (0 when None).
        min_duration, max_duration: Manifest filtering.
        normalize_transcripts: Passed to the parser as ``do_normalize``.
        trim_silence: Stored as ``self.trim``.
        shuffle_n: WebDataset shuffle buffer size.
    """
    super().__init__()

    self._sample_rate = sample_rate
    if augmentor is not None:
        augmentor = _process_augmentations(augmentor)

    self.collection = ASRAudioText(
        manifests_files=manifest_filepath.split(','),
        parser=make_parser(labels=labels, name='en', do_normalize=normalize_transcripts),
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.featurizer = WaveformFeaturizer(
        sample_rate=self._sample_rate, int_values=int_values, augmentor=augmentor)
    self.trim = trim_silence
    self.eos_id = eos_id
    self.bos_id = bos_id

    # Used in creating a sampler (in Actions).
    self._batch_size = batch_size
    self._num_workers = num_workers
    token_pad = 0 if pad_id is None else pad_id
    self.collate_fn = partial(seq_collate_fn, token_pad_value=token_pad)

    # Check for distributed and partition shards accordingly
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        n_workers = torch.distributed.get_world_size()
        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))
        total = len(audio_tar_filepaths)
        if total % n_workers:
            logging.warning(
                f"Number of shards in tarred dataset ({total}) is not divisible "
                f"by number of distributed workers ({n_workers}).")
        share = total // n_workers
        start = share * rank
        audio_tar_filepaths = audio_tar_filepaths[start:start + share]

    # Put together WebDataset
    pipeline = wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n)
    pipeline = pipeline.rename(audio='wav', key='__key__').to_tuple('audio', 'key')
    self._dataset = pipeline.pipe(self._filter).map(f=self._build_sample)
def test_dataset_shuffle_extract():
    """Shuffling then extracting (mask, rgb) tuples must keep all 64 samples."""
    shuffled = wds.Dataset(test_data).shuffle(5)
    pairs = shuffled.to_tuple("msk.png rgb.png")
    assert count_samples_tuple(pairs) == 64
def __init__(
    self,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    parser: Callable,
    sample_rate: int,
    int_values: bool = False,
    augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
    shuffle_n: int = 0,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    max_utts: int = 0,
    trim: bool = False,
    bos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
    add_misc: bool = False,
    pad_id: int = 0,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Index an ASR manifest and build a WebDataset over tarred audio files.

    Args:
        audio_tar_filepaths: Brace-pattern string or explicit list of tar paths.
            In a string, '(', '[', '<', '_OP_' become '{' and ')', ']', '>',
            '_CL_' become '}'.
        manifest_filepath: Comma-separated manifest paths.
        parser: Transcript parser forwarded to ``ASRAudioText``.
        sample_rate, int_values, augmentor: Forwarded to ``WaveformFeaturizer``.
        shuffle_n: WebDataset shuffle buffer size.
        min_duration, max_duration, max_utts: Manifest filters forwarded to
            ``ASRAudioText`` (``max_utts`` as ``max_number``).
        trim, bos_id, eos_id, pad_id, add_misc: Stored on the instance.
        shard_strategy: 'scatter' splits shards across ranks; 'replicate'
            gives every rank all shards.
        global_rank, world_size: Shards are partitioned only when
            ``world_size > 1``.

    Raises:
        ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
    """
    self.collection = collections.ASRAudioText(
        manifests_files=manifest_filepath.split(','),
        parser=parser,
        min_duration=min_duration,
        max_duration=max_duration,
        max_number=max_utts,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.featurizer = WaveformFeaturizer(
        sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
    self.trim = trim
    self.eos_id = eos_id
    self.bos_id = bos_id
    self.pad_id = pad_id
    self._add_misc = add_misc

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        # Opening tokens first, then closing tokens.
        for token, brace in [('(', "{"), ('[', "{"), ('<', "{"), ('_OP_', "{"),
                             (')', "}"), (']', "}"), ('>', "}"), ('_CL_', "}")]:
            audio_tar_filepaths = audio_tar_filepaths.replace(token, brace)

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")
        elif shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            shard_total = len(audio_tar_filepaths)
            if shard_total % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({shard_total}) is not divisible "
                    f"by number of distributed workers ({world_size}).")
            per_rank = shard_total // world_size
            begin_idx = per_rank * global_rank
            end_idx = begin_idx + per_rank
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        else:
            raise ValueError(
                f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    # Put together WebDataset
    pipeline = wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n)
    pipeline = pipeline.rename(audio='wav', key='__key__').to_tuple('audio', 'key')
    self._dataset = pipeline.pipe(self._filter).map(f=self._build_sample)
def _file_iter_to_line_iter(jsonl_iter): for jsonl in jsonl_iter: lines = jsonl["jsonl"].split(b"\n") for line in lines: if not line: continue json_line = json.loads(line) json_line["binary"] = jsonl["__key__"] yield json_line if __name__ == "__main__": print(sys.argv[1]) urls = sorted(glob.glob(sys.argv[1])) dataset = wds.Dataset(urls).pipe(_file_iter_to_line_iter) dataset = torch.utils.data.DataLoader(dataset, num_workers=8, batch_size=None) uniq_code = set() uniq_binary = set() token_len = [] num_vars = [] def tokenlen(example): return len(example["code_tokens"]) def num_var(example): return len(example["source"]) def name(example):
def main(args):
    """Train a VQ-VAE and an MLP entropy coder on CelebA, checkpointing every epoch.

    Data comes from ``CelebA('../data_vae')`` via a DataLoader (batch 64,
    16 workers, ``drop_last=True``). Both models are wrapped in
    ``_CustomDataParallel`` when more than one GPU is visible. After each
    epoch the VQ-VAE ``state_dict`` is saved as
    ``/scratch/xgitiaux/checkpoint/vqvae/two_q_vqvae_<epoch:03d>.pt``
    (1-based epoch number).
    """
    device = "cuda"
    dataset = CelebA('../data_vae')
    loader = DataLoader(dataset, batch_size=64, num_workers=16, drop_last=True)

    n_gpus = torch.cuda.device_count()

    model = VQVAE(cout=30).to(device)
    if n_gpus > 1:
        logger.info(f'Number of gpu is {n_gpus}')
        model = _CustomDataParallel(model)

    entropy_coder = MLP(32 * 32, depth=3, width=256).to(device)
    if n_gpus > 1:
        logger.info(f'Number of gpu is {n_gpus}')
        entropy_coder = _CustomDataParallel(entropy_coder)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    poptimizer = optim.Adam(list(entropy_coder.parameters()), lr=args.lr)
    scheduler = None

    checkpoint_dir = "/scratch/xgitiaux/checkpoint/vqvae"
    os.makedirs(checkpoint_dir, exist_ok=True)
    for i in range(args.epoch):
        # NOTE(review): the same coder is passed in two argument slots;
        # presumably train() expects separate top/bottom coders — confirm.
        train(i, loader, model, optimizer, scheduler, device,
              entropy_coder, entropy_coder, poptimizer)
        torch.save(
            model.state_dict(),
            f"{checkpoint_dir}/two_q_vqvae_{str(i + 1).zfill(3)}.pt",
        )
def sample_batch():
    """Return the first element (the image batch) of the first batch of ``bs`` samples."""
    pipeline = (
        wds.Dataset(test_data)
        .map(semsegment_decoder)
        .rename(image="rgb.png", mask="msk.png")
        .to_tuple("image", "mask")
        .batched(bs)
    )
    first_batch = next(iter(pipeline))
    return first_batch[0]
def test_rename():
    """``rename`` should leave exactly the new field names in each sample."""
    renamed = wds.Dataset(test_data).rename(image="rgb.png", mask="msk.png")
    first = next(iter(renamed))
    assert set(first) == {"image", "mask"}
def test_slice():
    """``slice(10)`` should cap the dataset at its first ten samples."""
    limited = wds.Dataset(test_data).slice(10)
    assert count_samples_tuple(limited) == 10
def test_dataset_pipe_cat():
    """Reading the shard through a ``pipe:cat`` URL should yield all 64 samples."""
    source = f"pipe:cat {test_data}"
    pairs = wds.Dataset(source).shuffle(5).to_tuple("msk.png rgb.png")
    assert count_samples_tuple(pairs) == 64
def test_multi():
    """MultiDataset over ``k`` URLs of the local shard should yield ``47 * k`` samples."""
    for copies in (1, 4, 17):
        # The "# i" suffix makes each URL distinct; presumably the shell
        # treats it as a comment so every pipe runs the same `cat`.
        urls = [f"pipe:cat {local_data} # {i}" for i in range(copies)]
        pipeline = wds.Dataset(urls).decode().shuffle(5).to_tuple("png;jpg cls")
        combined = multi.MultiDataset(pipeline, workers=4)
        assert count_samples_tuple(combined) == 47 * copies
def __init__(
    self,
    text_tar_filepaths: str,
    metadata_path: str,
    encoder_tokenizer: str,
    decoder_tokenizer: str,
    shuffle_n: int = 1,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
    reverse_lang_direction: bool = False,
):
    """Build a WebDataset over tarred, pickled translation batches.

    Args:
        text_tar_filepaths: Brace-pattern string (or list) of tar shard paths.
            In a string, '(', '[', '<', '_OP_' become '{' and ')', ']', '>',
            '_CL_' become '}'; the pattern is then brace-expanded to a list.
        metadata_path: JSON metadata file, loaded into ``self.metadata``.
        encoder_tokenizer, decoder_tokenizer: Tokenizer objects; their
            ``pad_id`` attributes are read (NOTE: annotated ``str`` but must
            expose ``.pad_id``).
        shuffle_n: WebDataset shuffle buffer size.
        shard_strategy: 'scatter' splits shards across ranks; 'replicate'
            gives every rank all shards.
        global_rank, world_size: Shards are partitioned only when
            ``world_size > 1`` (the default 0 previously caused a
            ZeroDivisionError under 'scatter').
        reverse_lang_direction: Stored on the instance.

    Raises:
        ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
    """
    super(TarredTranslationDataset, self).__init__()

    self.encoder_tokenizer = encoder_tokenizer
    self.decoder_tokenizer = decoder_tokenizer
    self.reverse_lang_direction = reverse_lang_direction
    self.src_pad_id = encoder_tokenizer.pad_id
    self.tgt_pad_id = decoder_tokenizer.pad_id

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    self.metadata = metadata

    if isinstance(text_tar_filepaths, str):
        # Replace '(', '[', '<' and '_OP_' with '{'
        for bkey in ['(', '[', '<', '_OP_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")
        # Replace ')', ']', '>' and '_CL_' with '}'
        for bkey in [')', ']', '>', '_CL_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

    if isinstance(text_tar_filepaths, str):
        # Brace expand
        text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

    # Partition only when running distributed: the modulo/division below is
    # undefined for world_size == 0, and world_size == 1 keeps every shard.
    if world_size > 1:
        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            if len(text_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")
            begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
            logging.info('Begin Index : %d' % (begin_idx))
            logging.info('End Index : %d' % (end_idx))
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")

    self.tarpath = text_tar_filepaths

    # Put together WebDataset
    self._dataset = (
        wd.Dataset(text_tar_filepaths)
        .shuffle(shuffle_n)
        .rename(pkl='pkl', key='__key__')
        .to_tuple('pkl', 'key')
        .map(f=self._build_sample)
    )
def __init__(
    self,
    text_tar_filepaths: str,
    metadata_path: str,
    tokenizer,
    max_seq_length: int = 512,
    batch_step: int = None,
    shuffle_n: int = 1,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Build a WebDataset over tarred ``.npy`` language-modeling shards.

    Args:
        text_tar_filepaths: Brace-pattern string (or list) of tar shard paths.
            In a string, '(', '[', '<', '_OP_' become '{' and ')', ']', '>',
            '_CL_' become '}'.
        metadata_path: JSON metadata file, loaded into ``self.metadata``.
        tokenizer: Stored on ``self.tokenizer``.
        max_seq_length: Stored on ``self.max_seq_length``.
        batch_step: Stored on ``self.batch_step``; falls back to
            ``max_seq_length`` when falsy.
        shuffle_n: WebDataset shuffle buffer size.
        shard_strategy: 'scatter' splits shards across ranks; 'replicate'
            gives every rank all shards.
        global_rank, world_size: Shards are partitioned only when
            ``world_size > 1``.

    Raises:
        ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
    """
    super(TarredL2RLanguageModelingDataset, self).__init__()

    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.batch_step = batch_step or self.max_seq_length

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    self.metadata = metadata

    if isinstance(text_tar_filepaths, str):
        # Replace '(', '[', '<' and '_OP_' with '{'
        for bkey in ['(', '[', '<', '_OP_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")
        # Replace ')', ']', '>' and '_CL_' with '}'
        for bkey in [')', ']', '>', '_CL_']:
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

    # Partition only when running distributed; with the default world_size=0
    # the modulo/division below would raise ZeroDivisionError.
    if world_size > 1:
        if isinstance(text_tar_filepaths, str):
            # Expand to an explicit shard list first: len()/slicing on the raw
            # pattern string would count and cut characters, not shards.
            import braceexpand
            text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            if len(text_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")
            begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")

    self.tarpath = text_tar_filepaths

    # Put together WebDataset
    self._dataset = (
        wd.Dataset(text_tar_filepaths)
        .shuffle(shuffle_n)
        .rename(npy='npy', key='__key__')
        .to_tuple('npy', 'key')
        .map(f=self._build_sample)
    )
def test_dataset():
    """Iterating the raw test shard should yield its 64 samples."""
    raw = wds.Dataset(test_data)
    assert count_samples_tuple(raw) == 64