def build_dataloader(location, shuffle_dataset, sampling_fraction, config, collate_fn, tokenizer,
                     continuous_iter=True, world_size=1, num_workers=1):
    # Maximum padded sequence length -> batch size; shorter sequences get larger batches
    # so the token count per step stays roughly constant.
    size_dicts = {128: 64 * 8, 256: 32 * 8, 512: 16 * 8, 768: 8 * 8, 1024: 8 * 8}
    # TODO: num workers based on dataset size, only top 16 datasets get 2 workers,
    #       next 16 get 1 worker and the rest are done in the main process
    single_node = world_size == 1
    try:
        train_dataset = Dataset.load_from_disk(location)
        train_dataset = TokenizerDataset(config, tokenizer, char_to_id,
                                         dict(padding="max_length", truncation=True, return_tensors="pt",
                                              max_length=config.tokenizer_length),
                                         train_dataset)
        if num_workers > 0:
            train_loader = DataLoader(train_dataset,
                                      sampler=None if single_node else DistributedSampler(train_dataset, shuffle=shuffle_dataset),
                                      batch_size=8 * 8, collate_fn=collate_fn, prefetch_factor=4,
                                      num_workers=(2 * num_workers) if single_node else num_workers)
        else:
            train_loader = DataLoader(train_dataset,
                                      sampler=None if single_node else DistributedSampler(train_dataset, shuffle=shuffle_dataset),
                                      batch_size=8 * 8, collate_fn=collate_fn, num_workers=0)
        train_loader = custom_batching_fn(train_loader, size_dicts, continuous_iter)
    except Exception:
        # `location` holds a DatasetDict of several named datasets rather than a single Dataset.
        train_dataset = DatasetDict.load_from_disk(location)
        train_dataset = {k: v for k, v in train_dataset.items() if len(v) >= world_size}
        # Sample datasets with probability proportional to len(dataset) ** sampling_fraction.
        train_dataset_sampling_proba = {k: len(v) ** sampling_fraction for k, v in train_dataset.items()}
        lsum = sum(train_dataset_sampling_proba.values())
        train_dataset_sampling_proba = {k: v / lsum for k, v in train_dataset_sampling_proba.items()}
        train_dataset = {k: TokenizerDataset(config, tokenizer, char_to_id,
                                             dict(padding="max_length", truncation=True, return_tensors="pt",
                                                  max_length=config.tokenizer_length),
                                             v)
                         for k, v in train_dataset.items()}
        # for v in train_dataset.values():
        #     v.training = False
        if num_workers > 0:
            train_loader = {k: DataLoader(v,
                                          sampler=None if single_node else DistributedSampler(v, shuffle=shuffle_dataset),
                                          batch_size=8 * 8, collate_fn=collate_fn, prefetch_factor=2,
                                          num_workers=(2 * num_workers) if single_node else num_workers)
                            for k, v in train_dataset.items()}
        else:
            train_loader = {k: DataLoader(v,
                                          sampler=None if single_node else DistributedSampler(v, shuffle=shuffle_dataset),
                                          batch_size=8 * 8, collate_fn=collate_fn, num_workers=0)
                            for k, v in train_dataset.items()}
        train_loader = {k: custom_batching_fn(dataloader, size_dicts, continuous_iter)
                        for k, dataloader in train_loader.items()}
        train_loader = datadict_iterator(train_loader, train_dataset_sampling_proba)
    return train_loader
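# `custom_batching_fn` and `datadict_iterator` are defined elsewhere in this codebase.
# Below is a minimal sketch of the length-bucketed rebatching that `custom_batching_fn`
# appears to perform, assuming each batch is a dict of tensors with an "input_ids" key;
# the name `rebatch_by_length` and all details are assumptions, not the real implementation.
import torch

def rebatch_by_length(dataloader, size_dicts, continuous_iter=True):
    buckets = {length: [] for length in size_dicts}
    while True:
        for batch in dataloader:
            # Route the batch to the smallest length bucket that fits its sequences.
            seq_len = batch["input_ids"].shape[1]
            fits = [l for l in size_dicts if l >= seq_len]
            length = min(fits) if fits else max(size_dicts)
            buckets[length].append(batch)
            # Number of loader-sized batches needed to reach this bucket's target size.
            wanted = max(1, size_dicts[length] // batch["input_ids"].shape[0])
            if len(buckets[length]) >= wanted:
                yield {k: torch.cat([b[k] for b in buckets[length]]) for k in batch}
                buckets[length] = []
        if not continuous_iter:
            break

# Note how size_dicts keeps tokens per step roughly constant:
# 128 * 512 = 256 * 256 = 512 * 128 = 1024 * 64 = 65536 tokens per merged batch.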
batch_size=256)
dataset_filtered.save_to_disk("/home/ahemf/processed/c4_256")

fmap = get_filter_mapper(448)
dataset_448 = dataset_filtered.map(fmap, batched=True, batch_size=256, remove_columns=['timestamp'])
dataset_448 = dataset_448.map(lambda x: dict(text=list(map(lambda y: clean_text(y), x["text"]))),
                              batched=True, batch_size=256)
dataset_448.save_to_disk("/home/ahemf/processed/c4_448")

c4 = DatasetDict.load_from_disk("/home/ahemf/processed/c4_448")
dsets = Dataset.load_from_disk("/home/ahemf/processed/dsets_448")

# Give both datasets the same schema before concatenation: tag each c4 split with a
# `dataset` source column and drop the columns that dsets does not have.
c4['train'] = c4['train'].add_column('dataset', ['c4'] * len(c4['train']))
c4['train'] = c4['train'].remove_columns(['url', 'timestamp'])
c4['validation'] = c4['validation'].remove_columns(['url', 'timestamp'])
c4['validation'] = c4['validation'].add_column('dataset', ['c4'] * len(c4['validation']))

# Re-append the `dataset` column of dsets so its column order matches c4['train'].
dataset_col = dsets['dataset']
dsets = dsets.remove_columns(["dataset"])
dsets = dsets.add_column("dataset", dataset_col)

c4["train"] = concatenate_datasets([c4["train"], dsets])
c4["train"].save_to_disk("/home/ahemf/processed/c4_extended")
c4 = Dataset.load_from_disk("/home/ahemf/processed/c4_extended")
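# `get_filter_mapper` and `clean_text` are defined elsewhere. A plausible sketch of the
# batched mapper that `get_filter_mapper` returns, assuming it selects texts by whitespace
# word count; keeping texts of AT LEAST `n_words` is itself an assumption (chosen so that
# re-filtering an already length-filtered dataset, as done here and below, is not a no-op).
def get_filter_mapper(n_words):
    def fmap(batch):
        keep = [i for i, t in enumerate(batch["text"]) if len(t.split()) >= n_words]
        return {k: [v[i] for i in keep] for k, v in batch.items()}
    return fmap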
def build_dataloader(location, shuffle_dataset, sampling_fraction, config, collate_fn, tokenizer, size_dicts,
                     continuous_iter=True, world_size=1, num_workers=1):
    # Every bucket batch size must be an integer multiple of the smallest one, so that
    # loader batches of `min_size` can be merged exactly into any bucket.
    assert max(size_dicts.values()) % min(size_dicts.values()) == 0
    single_node = world_size == 1
    from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
    min_size = gcd_array(size_dicts.values())
    prefetch_factor = 2 * (max(size_dicts.values()) // min_size)
    try:
        train_dataset = Dataset.load_from_disk(location)
        train_dataset = TokenizerDataset(config, tokenizer, char_to_id,
                                         dict(padding="max_length", truncation=True, return_tensors="pt",
                                              max_length=config.tokenizer_length),
                                         train_dataset)
        if num_workers > 0:
            train_loader = DataLoader(train_dataset,
                                      sampler=None if single_node else DistributedSampler(train_dataset, shuffle=shuffle_dataset),
                                      batch_size=min_size, collate_fn=collate_fn,
                                      shuffle=shuffle_dataset and single_node,
                                      prefetch_factor=prefetch_factor, num_workers=num_workers, pin_memory=True)
        else:
            train_loader = DataLoader(train_dataset,
                                      sampler=None if single_node else DistributedSampler(train_dataset, shuffle=shuffle_dataset),
                                      batch_size=min_size, collate_fn=collate_fn,
                                      shuffle=shuffle_dataset and single_node,
                                      num_workers=0, pin_memory=True)
        train_loader = custom_batching_fn(train_loader, size_dicts, continuous_iter)
    except Exception:
        # `location` holds a DatasetDict of several named datasets rather than a single Dataset.
        train_dataset = DatasetDict.load_from_disk(location)
        train_dataset = {k: v for k, v in train_dataset.items() if len(v) >= world_size}
        # Sampling probability proportional to len(dataset) ** sampling_fraction, normalized over datasets.
        train_dataset_sampling_proba = {k: len(v) ** sampling_fraction for k, v in train_dataset.items()}
        lsum = sum(train_dataset_sampling_proba.values())
        train_dataset_sampling_proba = {k: v / lsum for k, v in train_dataset_sampling_proba.items()}
        train_dataset = {k: TokenizerDataset(config, tokenizer, char_to_id,
                                             dict(padding="max_length", truncation=True, return_tensors="pt",
                                                  max_length=config.tokenizer_length),
                                             v)
                         for k, v in train_dataset.items()}
        # for v in train_dataset.values():
        #     v.training = False
        if num_workers > 0:
            train_loader = {k: DataLoader(v,
                                          sampler=None if single_node else DistributedSampler(v, shuffle=shuffle_dataset),
                                          shuffle=shuffle_dataset and single_node,
                                          batch_size=min_size, collate_fn=collate_fn,
                                          prefetch_factor=prefetch_factor, num_workers=num_workers)
                            for k, v in train_dataset.items()}
        else:
            train_loader = {k: DataLoader(v,
                                          sampler=None if single_node else DistributedSampler(v, shuffle=shuffle_dataset),
                                          shuffle=shuffle_dataset and single_node,
                                          batch_size=min_size, collate_fn=collate_fn, num_workers=0)
                            for k, v in train_dataset.items()}
        train_loader = {k: custom_batching_fn(dataloader, size_dicts, continuous_iter)
                        for k, dataloader in train_loader.items()}
        train_loader = datadict_iterator(train_loader, train_dataset_sampling_proba)
    return train_loader
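# `gcd_array` is not shown in this section. Under the natural assumption that it reduces
# math.gcd over its input, min_size is the greatest common divisor of the bucket batch
# sizes, so every bucket holds a whole number of loader batches:
import math
from functools import reduce

def gcd_array(values):
    return reduce(math.gcd, values)

# Hypothetical call, reusing the bucket sizes hard-coded in the earlier version of this
# function; `config`, `collate_fn`, and `tokenizer` are assumed to be in scope, and
# sampling_fraction=0.7 is an arbitrary placeholder.
size_dicts = {128: 64 * 8, 256: 32 * 8, 512: 16 * 8, 768: 8 * 8, 1024: 8 * 8}
train_loader = build_dataloader("/home/ahemf/processed/c4_extended", shuffle_dataset=True,
                                sampling_fraction=0.7, config=config, collate_fn=collate_fn,
                                tokenizer=tokenizer, size_dicts=size_dicts,
                                continuous_iter=True, world_size=1, num_workers=2)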
fmap = get_filter_mapper(448)
dsets_448 = dsets_256.map(fmap, batched=True, batch_size=256)
dsets_448.save_to_disk("/home/ahemf/processed/dsets_448")

########################################################################################################

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
import torch
from torch.nn.parallel.replicate import replicate
import os

os.environ['TOKENIZERS_PARALLELISM'] = "true"
device = torch.device("cuda")
sbert = SentenceTransformer("paraphrase-mpnet-base-v2").eval().to(device)
dset = Dataset.load_from_disk("/home/ahemf/processed_datasets/dsets_448")

# Embed every text with SBERT, batch by batch, without tracking gradients.
with torch.no_grad():
    dset_sbert = dset.map(lambda x: dict(sbert=sbert.encode(x["text"])), batched=True, batch_size=128)

device = torch.device("cuda")

# L2-normalize the stored embeddings on the GPU so dot products become cosine similarities.
def normalize_sbert(x):
    t = torch.tensor(x["sbert"]).to(device)
    t = t / t.norm(2, -1, True)
    return dict(sbert=t.tolist())

dset_sbert = dset_sbert.map(normalize_sbert, batched=True, batch_size=8192)

# Assign a running integer id to each row.
cnt = [1]

def add_id(x):
    ids = list(range(cnt[0], cnt[0] + len(x["text"])))
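# `add_id` is cut off above. A hedged sketch of how such a running-counter map usually
# completes; the column name `identifier` and the final `map` call are assumptions.
cnt = [1]

def add_id(x):
    ids = list(range(cnt[0], cnt[0] + len(x["text"])))
    cnt[0] += len(x["text"])  # advance the counter so the next batch continues the sequence
    return dict(identifier=ids)

dset_sbert = dset_sbert.map(add_id, batched=True, batch_size=8192)
# A mutable `cnt` works only because `map` with the default num_proc=1 visits batches
# sequentially; with num_proc > 1 every worker would get its own copy of the counter.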
import os
import random

from datasets import Dataset, concatenate_datasets

random.seed(12345)

if __name__ == "__main__":
    ori_dataset = Dataset.load_from_disk('disk/enwiki_bookcorpus-tiny-disk')
    rep_dataset = Dataset.load_from_disk('disk/enwiki_bookcorpus-tiny-wrep-disk')
    ori_num = ori_dataset.num_rows
    rep_num = rep_dataset.num_rows
    # Sample one distinct replacement row for every original row.
    rep_list = random.sample(range(rep_num), ori_num)
    start_idx = 0

    def dataset_merge(examples):
        global start_idx
        input_ids = examples['input_ids']
        input_ids = [ids.detach().numpy().tolist() for ids in input_ids]
        end_idx = start_idx + len(input_ids)
        # Replacement rows aligned with this batch of original rows.
        slc_list = rep_list[start_idx:end_idx]
        print(start_idx, end_idx)
        start_idx = end_idx
        original_sent = []
        synonym_sent = []
        antonym_sent = []
        synonym_antonym_sent = []
        replace_label = []
        for s in slc_list:
            t_d = rep_dataset[s]
            original_sent.append(t_d['original_sent'])
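    # The script breaks off inside `dataset_merge` above. Since the function is shaped
    # like a batched `map` callable and calls .detach() on input_ids, it was presumably
    # applied roughly as below; with_format("torch") and the batch size are assumptions.
    ori_dataset = ori_dataset.with_format("torch")  # so input_ids arrive as tensors
    merged_dataset = ori_dataset.map(dataset_merge, batched=True, batch_size=1000)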
def dataset_builder(location, params):
    from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
    dataset = Dataset.load_from_disk(location)
    dataset = MTTDataset(dataset=dataset, **params)
    return dataset
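# Hypothetical invocation; MTTDataset and the keys it accepts in `params` are defined
# elsewhere, so the keyword shown below is only a placeholder.
train_dataset = dataset_builder("/home/ahemf/processed/c4_extended", params=dict(tokenizer=tokenizer))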