def evaluate_ocnli(model, dev_dataloader, device, args):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm.tqdm(dev_dataloader):
            tokens_1, masks_1, tokens_2, masks_2, tokens_3, masks_3, labels = [x.to(device) for x in batch]

            # Candidate 1 ("对", entailment prompt): mean LM loss over the masked
            # hypothesis positions, gathered across data-parallel ranks.
            tokens, attention_mask, position_ids = get_batch(tokens_1, args)
            output, _ = model(tokens, position_ids, attention_mask)
            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])
            output_1 = torch.sum(losses * masks_1, 1) / torch.sum(masks_1, -1)
            tensor_list = [torch.zeros_like(output_1) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output_1, mpu.get_data_parallel_group())
            output_1 = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()

            # Candidate 2 ("错", contradiction prompt).
            tokens, attention_mask, position_ids = get_batch(tokens_2, args)
            output, _ = model(tokens, position_ids, attention_mask)
            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])
            output_2 = torch.sum(losses * masks_2, 1) / torch.sum(masks_2, -1)
            tensor_list = [torch.zeros_like(output_2) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output_2, mpu.get_data_parallel_group())
            output_2 = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()

            # Candidate 3 ("也许", neutral prompt).
            tokens, attention_mask, position_ids = get_batch(tokens_3, args)
            output, _ = model(tokens, position_ids, attention_mask)
            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])
            output_3 = torch.sum(losses * masks_3, 1) / torch.sum(masks_3, -1)
            tensor_list = [torch.zeros_like(output_3) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output_3, mpu.get_data_parallel_group())
            output_3 = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()

            # Gather the labels and let rank 0 count argmin predictions.
            tensor_list_labels = [torch.zeros_like(labels) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list_labels, labels, mpu.get_data_parallel_group())
            if torch.distributed.get_rank() == 0:
                labels = torch.stack(tensor_list_labels, 0)
                labels = labels.view(-1).cpu().detach().numpy()
                res = [np.argmin(np.array(x)) for x in zip(output_1, output_2, output_3)]
                res = [x == y for x, y in zip(res, labels)]
                correct += sum(res)
                total += len(res)
    if torch.distributed.get_rank() == 0:
        print("EVAL", correct, total)
def evaluate(model, dev_dataloader, all_labels, device, args):
    model.eval()
    if torch.distributed.get_rank() == 0:
        res = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dev_dataloader):
            tokens, masks = [x.to(device) for x in batch]
            tokens, attention_mask, position_ids = get_batch(tokens, args)
            output, _ = model(tokens, position_ids, attention_mask)
            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])
            # Mean LM loss over the masked positions of each candidate.
            output = torch.sum(losses * masks, 1) / torch.sum(masks, -1)
            tensor_list = [torch.zeros_like(output) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output, mpu.get_data_parallel_group())
            output = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()
            if torch.distributed.get_rank() == 0:
                for v in output:
                    res.append(v)
    if torch.distributed.get_rank() == 0:
        cnt = 0
        # Each consecutive group of label_size losses belongs to one instance;
        # the prediction is the candidate with the lowest loss.
        label_size = max(all_labels) + 1
        num_inst = len(res) // label_size
        for x in range(num_inst):
            label = all_labels[x]
            cur_res = res[x * label_size:(x + 1) * label_size]
            pos = np.argmin(cur_res)
            if pos == label:
                cnt += 1
        print("EVAL", cnt, num_inst)
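# Hedged sketch (not part of the original files): the two evaluators above share
# one idea -- score each candidate continuation by its mean per-token LM loss and
# predict the candidate with the lowest loss. The helper below reproduces that
# arithmetic with plain torch tensors, without model parallelism or all_gather;
# the function name is illustrative, not from the source.
import torch

def pick_candidate_by_lm_loss(per_token_losses, masks):
    """per_token_losses: (num_candidates, seq_len - 1) LM losses for one instance.
    masks: same shape, 1.0 on the positions being scored.
    Returns the index of the candidate with the lowest mean masked loss."""
    mean_loss = torch.sum(per_token_losses * masks, dim=1) / torch.sum(masks, dim=1)
    return torch.argmin(mean_loss).item()

# Toy usage: candidate 1 has the lower average loss on its scored positions.
_losses = torch.tensor([[2.0, 3.0, 0.0], [0.5, 0.5, 0.0]])
_masks = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
assert pick_candidate_by_lm_loss(_losses, _masks) == 1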
def load_data(args, data_type, tokenizer, ratio=1):
    data_path = args.data_dir

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Dataset.
    filename = os.path.join(data_path, data_type + '.json')
    dataset = CHIDDataset(args, filename, data_type, tokenizer, ratio=ratio)

    # Use a random sampler with distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True,
                                       collate_fn=dataset.collate), dataset
def build_data_loader(dataset, batch_size, num_workers, drop_last, shuffle=True, only_rank0=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""
    # Sampler.
    if only_rank0:
        rank, world_size = 0, 1
    else:
        world_size = mpu.get_data_parallel_world_size()
        rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True,
                                              collate_fn=my_collate)
    return data_loader
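# Hedged sketch: how torch.utils.data.distributed.DistributedSampler shards a
# dataset across data-parallel ranks, independent of mpu. Each rank iterates an
# equally sized slice (padded when the size does not divide evenly), which is the
# contract build_data_loader relies on. Helper name is illustrative only.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

def shard_indices(dataset_size, world_size, rank, shuffle=False):
    ds = TensorDataset(torch.arange(dataset_size))
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=shuffle)
    return list(sampler)

# With 10 samples over 4 replicas the sampler pads to 12, so every rank gets 3 indices.
assert len(shard_indices(10, 4, 0)) == 3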
def evaluate_tnews(args, model, dataloader, device, mode="dev"):
    model.eval()
    all_truth, all_preds = [], []
    with torch.no_grad():
        for batch, no_model_batch in tqdm(dataloader, desc="Evaluating {}".format(mode),
                                          disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / torch.sum(
                no_model_batch["loss_mask"], -1).unsqueeze(-1)

            # Gather the output logits from the other GPUs.
            tensor_list = [torch.zeros_like(output) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output, mpu.get_data_parallel_group())

            # Gather the truth labels from the other GPUs.
            tensor_list_truth = [torch.zeros_like(no_model_batch["truth"], dtype=torch.long)
                                 for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list_truth, no_model_batch["truth"],
                                         mpu.get_data_parallel_group())

            if args.model_parallel_size == 1:
                scores = torch.stack(tensor_list, 0).view(-1, 30000)
            else:
                assert args.model_parallel_size == 2, "Now, we only support model parallel <= 2"
                # For convenience. Note that the truth labels only appear in the
                # first 15000 entries of the logits, e.g. on ranks 0, 2, 4, ...
                scores = torch.stack(tensor_list, 0).view(-1, 15000)

            truth = torch.stack(tensor_list_truth, 0)
            truth = truth.view(-1)
            # scores = scores[:, cand_ids]
            preds = torch.argmax(scores, dim=-1)
            all_truth.extend(truth.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())

    acc = sum([int(p == l) for p, l in zip(all_preds, all_truth)]) / len(all_truth)
    acc = torch.tensor(acc).to(device)
    acc_list = [torch.zeros_like(acc) for _ in range(mpu.get_model_parallel_world_size())]
    torch.distributed.all_gather(acc_list, acc, mpu.get_model_parallel_group())
    return acc_list[0].item(), all_truth, all_preds
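# Hedged sketch of the vocabulary split assumed above: with model parallel size 2
# a 30000-token softmax is split into two 15000-wide shards, and the code relies
# on every label id falling in the first shard (the logits gathered from ranks
# 0, 2, 4, ...). The helper below is illustrative, not a function in the source.
def vocab_shard(vocab_size, model_parallel_size, mp_rank):
    shard = vocab_size // model_parallel_size
    return mp_rank * shard, (mp_rank + 1) * shard

assert vocab_shard(30000, 2, 0) == (0, 15000)  # the label ids must live here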
def make_gpt2_dataloaders(args):
    # Input parameters.
    input_data_sizes_file = args.input_data_sizes_file
    seq_length = args.seq_length
    initial_seed = args.seed

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    def make_data_loader_(data_path):
        # Build the dataset.
        dataset = GPT2Dataset(data_path, input_data_sizes_file, seq_length, initial_seed)
        # Use a simple sampler with distributed batch sampler.
        sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = DistributedBatchSampler(sampler=sampler,
                                                batch_size=global_batch_size,
                                                drop_last=True,
                                                rank=rank,
                                                world_size=world_size)
        # Torch dataloader.
        return torch.utils.data.DataLoader(dataset,
                                           batch_sampler=batch_sampler,
                                           num_workers=num_workers,
                                           pin_memory=True)

    train = make_data_loader_(args.train_data_path)
    valid = make_data_loader_(args.val_data_path)
    test = make_data_loader_(args.test_data_path)

    args.do_train = train is not None
    args.do_valid = valid is not None
    args.do_test = test is not None

    # Tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir)
    eod_token = tokenizer.encoder['<|endoftext|>']
    num_tokens = eod_token + 1

    return (train, valid, test), num_tokens, eod_token
def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0,
             gpt_infill_prob=0.5, gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15,
             average_block_length=3, max_block_length=40, block_mask_prob=0.0,
             context_mask_ratio=0.0, context_mask_range=3, short_seq_prob=0.0,
             single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False,
             shuffle_blocks=True, sentinel_token=False, task_mask=False,
             random_position=False, masked_lm=False):
    self.eod_token = args.eod_token
    self.tokenizer = tokenizer
    self.count = 0
    self.max_seq_length = max_seq_length
    self.rank = mpu.get_data_parallel_rank()
    self.world_size = mpu.get_data_parallel_world_size()
    # self.rank = 0
    # self.world_size = 1
    assert 0.0 <= bert_prob <= 1.0
    self.bert_prob = bert_prob
    self.gap_sentence_prob = gap_sentence_prob
    self.gpt_prob = 1 - bert_prob - gap_sentence_prob
    assert self.gpt_prob >= -1e-10
    self.infill_prob = gpt_infill_prob
    self.gpt_min_ratio = gpt_min_ratio
    self.bert_ratio = bert_ratio
    self.gap_sentence_ratio = gap_sentence_ratio
    self.block_length_distribution = [poisson.pmf(i, average_block_length)
                                      for i in range(1, max_block_length)]
    self.block_mask_prob = block_mask_prob
    self.context_mask_ratio = context_mask_ratio
    self.context_mask_range = context_mask_range
    self.short_seq_prob = short_seq_prob
    self.single_span_prob = single_span_prob
    self.block_position_encoding = block_position_encoding
    self.encoder_decoder = encoder_decoder
    self.shuffle_blocks = shuffle_blocks
    self.sentinel_token = sentinel_token
    self.generation_mask = 'gMASK' if task_mask else 'MASK'
    self.generation_mask = self.tokenizer.get_command(self.generation_mask).Id
    self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK'
    self.gap_sentence_mask = self.tokenizer.get_command(self.gap_sentence_mask).Id
    self.random_position = random_position
    self.masked_lm = masked_lm
    print_rank_0(
        f"BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, "
        f"GPT prob {self.gpt_prob}, infill prob {self.infill_prob}")
    print_rank_0(
        f"generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, "
        f"gap sent ratio {self.gap_sentence_ratio}")
    print_rank_0(f"block length distribution {self.block_length_distribution}")
    print_rank_0(f"block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}")
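# Hedged sketch of the span-length distribution initialized above: masked block
# lengths 1 .. max_block_length - 1 are weighted by a Poisson pmf with mean
# average_block_length (scipy.stats.poisson), so most spans land near the mean.
# Standalone illustration only; the helper name is not from the source.
from scipy.stats import poisson

def block_length_distribution(average_block_length, max_block_length):
    return [poisson.pmf(i, average_block_length) for i in range(1, max_block_length)]

_dist = block_length_distribution(3, 40)
# Index i corresponds to length i + 1; a Poisson with mean 3 peaks at lengths 2 and 3.
assert max(range(len(_dist)), key=lambda i: _dist[i]) in (1, 2)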
def test_initialize_model_parallel(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            model_parallel_size))
    model_parallel_size_ = min(model_parallel_size,
                               torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel. Use the clamped size here as well; the unclamped
    # model_parallel_size would give the wrong rank whenever it exceeds the world size.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups.
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
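# Hedged sketch of the rank arithmetic the test above checks. With model parallel
# size m, Megatron-style grouping gives rank r the model-parallel rank r % m and
# the data-parallel rank r // m. Illustrative helper, not part of mpu.
def parallel_coordinates(rank, model_parallel_size):
    model_parallel_rank = rank % model_parallel_size
    data_parallel_rank = rank // model_parallel_size
    return model_parallel_rank, data_parallel_rank

# 8 GPUs, model parallel 2: rank 5 is model-parallel rank 1, data-parallel rank 2.
assert parallel_coordinates(5, 2) == (1, 2)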
def make_data_loader(dataset):
    """Build a dataloader given an input dataset."""
    if dataset is None:
        return None
    args = get_args()

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a simple sampler with distributed batch sampler.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
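# Hedged sketch: make_data_loader hands the *global* batch size to
# DistributedBatchSampler, which then gives each data-parallel rank a contiguous
# slice of size args.batch_size. The slicing below mirrors that contract (the
# real sampler also handles drop_last and resumption via start_iter); the helper
# name is illustrative.
def local_slice(global_batch, rank, world_size):
    per_rank = len(global_batch) // world_size
    return global_batch[rank * per_rank:(rank + 1) * per_rank]

# Global batch of 8 over 4 ranks: rank 1 sees samples 2 and 3.
assert local_slice(list(range(8)), 1, 4) == [2, 3]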
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Get the tokenizer.
    tokenizer = GPT2Tokenizer(
        os.path.join(args.tokenizer_path, 'vocab.json'),
        os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # Load data.
    test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
    # Set an arbitrary positive integer since the optimizer and the scheduler
    # will not be used when doing eval.
    args.train_iters = 1

    # Model.
    model, _, _ = setup_model_and_optimizer(args)
    device = torch.cuda.current_device()

    # Give a timestamp to the results directory.
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir, "{}-{}".format(args.model_name, cur_time))
    if torch.distributed.get_rank() == 0:
        os.makedirs(results_dir, exist_ok=True)

    model.eval()
    all_sids = []
    all_cids = []
    all_losses = []
    with torch.no_grad():
        for batch, no_model_batch in tqdm(
                test_dataloader, desc="Evaluating",
                disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), no_model_batch["labels"])
            loss_mask = no_model_batch["loss_mask"]
            loss = torch.sum(losses * loss_mask, dim=-1) / loss_mask.sum(dim=-1)

            loss_tensor_list = [
                torch.zeros_like(loss).to(device)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(loss_tensor_list, loss.data,
                                         group=mpu.get_data_parallel_group())
            all_losses.extend(loss_tensor_list)

            sids = no_model_batch["sids"]
            sid_tensor_list = [
                torch.zeros_like(sids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(sid_tensor_list, sids.data,
                                         group=mpu.get_data_parallel_group())
            all_sids.extend(sid_tensor_list)

            cids = no_model_batch["cids"]
            cid_tensor_list = [
                torch.zeros_like(cids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(cid_tensor_list, cids.data,
                                         group=mpu.get_data_parallel_group())
            all_cids.extend(cid_tensor_list)

    if torch.distributed.get_rank() == 0:
        all_losses = torch.stack(all_losses).view(-1).cpu().detach().numpy()
        all_sids = torch.stack(all_sids).view(-1).cpu().detach().numpy()
        all_cids = torch.stack(all_cids).view(-1).cpu().detach().numpy()

        truth_labels = test_dataset.truth_labels
        preds = [[] for _ in truth_labels]
        # Group candidate losses by sample id; predict the candidate with the lowest loss.
        for sid, cid, loss in zip(all_sids, all_cids, all_losses):
            preds[sid].append((cid, loss))
        preds = [min(p, key=lambda x: x[1])[0] for p in preds if len(p) > 0]

        yprint("Acc: {}".format(
            sum([int(p == l) for p, l in zip(preds, truth_labels)]) / len(truth_labels)))
        with open(os.path.join(results_dir, "zero-shot_result.txt"), "w") as f:
            f.write("Acc: {}\n".format(
                sum([int(p == l) for p, l in zip(preds, truth_labels)]) / len(truth_labels)))

    torch.distributed.barrier()
def finetune(args,
             train_valid_datasets_provider,
             model_kwargs,
             forward_step=finetune_forward_step,
             end_of_epoch_callback_provider=None):
    """Main finetune function used across all tasks."""
    global tokenizer
    timers = Timers()
    tokenizer = prepare_tokenizer(args)
    pretrain_glm.tokenizer = tokenizer
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)

    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloader').start()
    train_dataloader, valid_dataloader = None, None
    train_block_dataloader, valid_block_dataloader = None, None
    if train_valid_datasets_provider is not None and args.epochs > 0:
        if mpu.get_model_parallel_rank() == 0:
            train_dataset, valid_dataset = train_valid_datasets_provider(args, tokenizer)
            train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
                train_dataset, valid_dataset, args)
            if args.no_validation:
                valid_dataloader = None
            train_iters = torch.cuda.LongTensor([len(train_dataloader)])
        else:
            train_iters = torch.cuda.LongTensor([0])
        # Broadcast the number of train iterations so that ranks without a real
        # dataloader can run the same number of steps with a FakeDataloader.
        torch.distributed.broadcast(train_iters,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        if mpu.get_model_parallel_rank() != 0:
            args.train_iters_per_epoch = train_iters[0].item()
            args.train_iters = args.epochs * args.train_iters_per_epoch
            train_dataloader = FakeDataloader(args.train_iters_per_epoch)
            if args.no_validation:
                valid_dataloader = None
            else:
                valid_dataloader = FakeDataloader(None)
        if args.block_lm_ratio > 0.0:
            if mpu.get_model_parallel_rank() == 0:
                train_block_dataset, valid_block_dataset = train_valid_datasets_provider(
                    args, tokenizer, pattern_text=True)
                train_block_dataloader = make_data_loader(
                    train_block_dataset, tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    args.train_iters, args, shuffle=True, block_collate=True)
                valid_block_dataloader = make_data_loader(
                    valid_block_dataset, tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    (args.train_iters // args.eval_interval + 1) * args.eval_iters,
                    args, shuffle=True, block_collate=True)
            else:
                train_block_dataloader = FakeDataloader(args.train_iters)
                valid_block_dataloader = FakeDataloader(None)
            train_block_dataloader, valid_block_dataloader = iter(
                train_block_dataloader), iter(valid_block_dataloader)
    timers('train/valid/test dataset/dataloader').stop()

    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback, end_of_train_callback = None, None
    if end_of_epoch_callback_provider is not None:
        if train_valid_datasets_provider is not None and args.epochs > 0 and not args.no_validation:
            end_of_epoch_callback = end_of_epoch_callback_provider(args, tokenizer, is_test=False)
        end_of_train_callback = end_of_epoch_callback_provider(args, tokenizer, is_test=True)
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args, **model_kwargs)
    timers('model and optimizer').stop()

    # If a pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.load_pretrained is not None and not args.pretrained_bert:
        task_tokens = None
        if args.continuous_prompt and args.prompt_init:
            if mpu.get_model_parallel_rank() == 0:
                dataset = train_dataloader.dataset
                processor, pvp = dataset.processor, dataset.pvp
                task_tokens = []
                for label in processor.get_labels():
                    verbalizer = pvp.verbalize(label)[0]
                    verbalizer_ids = tokenizer.EncodeAsIds(verbalizer).tokenization
                    task_tokens += verbalizer_ids
                print_rank_0("Task tokens: " + tokenizer.DecodeIds(task_tokens))
                num_task_tokens = len(task_tokens)
            else:
                num_task_tokens, task_tokens = 0, []
            # Broadcast the verbalizer token ids to the other model-parallel ranks.
            num_task_tokens = torch.cuda.LongTensor([num_task_tokens])
            torch.distributed.broadcast(num_task_tokens,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            num_task_tokens = num_task_tokens.item()
            if num_task_tokens > 0:
                if mpu.get_model_parallel_rank() == 0:
                    task_tokens = torch.cuda.LongTensor(task_tokens)
                else:
                    task_tokens = torch.empty(num_task_tokens,
                                              device=torch.cuda.current_device(),
                                              dtype=torch.long)
                torch.distributed.broadcast(task_tokens,
                                            mpu.get_model_parallel_src_rank(),
                                            group=mpu.get_model_parallel_group())
                task_tokens = task_tokens.tolist()
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"), timeout=-1):
            load_pretrained(model, args.load_pretrained, args, task_tokens=task_tokens)
        # This is critical when only the model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    if args.load is not None:
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"), timeout=-1):
            load_checkpoint(model, optimizer, lr_scheduler, args,
                            no_deepspeed=args.no_deepspeed_load)
        # This is critical when only the model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    torch.distributed.barrier()
    timers('pretrained checkpoint').stop()
    args.iteration = 0
    summary_writer = None
    if torch.distributed.get_rank() == 0:
        args.log_dir = get_log_dir(base=args.summary_dir, name=args.experiment_name)
        if os.path.exists(os.path.join(args.log_dir, "test_results.json")) \
                and args.load is None and not args.overwrite:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(args.log_dir))
        summary_writer = get_sample_writer(log_dir=args.log_dir, iteration=args.iteration)
        print_and_save_args(args, verbose=True, log_dir=args.log_dir)

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log([
        'train/valid/test dataset/dataloader', 'callback function',
        'model and optimizer', 'pretrained checkpoint'
    ])
    print_rank_0('training ...')

    # Finetune the model.
    score_dict = None
    if train_dataloader is not None and args.epochs > 0:
        if args.block_lm_ratio > 0.0:
            forward_step = mix_forward_step
        best_iteration = _train(model, optimizer, lr_scheduler, forward_step,
                                (train_dataloader, train_block_dataloader),
                                (valid_dataloader, valid_block_dataloader),
                                end_of_epoch_callback, args, timers,
                                summary_writer=summary_writer)
        if end_of_train_callback is not None and best_iteration is not None:
            with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"), timeout=-1):
                args.load = os.path.join(args.save, "best")
                load_checkpoint(model, optimizer, lr_scheduler, args,
                                no_load_optim=True, no_deepspeed=True)
                args.load = None
        torch.distributed.barrier()
        if end_of_train_callback is not None:
            score_dict = end_of_train_callback(model, epoch=-1, output_predictions=True)
    # Or just evaluate.
    else:
        if end_of_train_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            score_dict = end_of_train_callback(model, epoch=-1, output_predictions=True)
    if score_dict is not None and torch.distributed.get_rank() == 0:
        score_dict.update({"type": "test"})
        with open(os.path.join(args.log_dir, "test_results.json"), "w") as output:
            output.write(json.dumps(score_dict) + "\n")
    print_rank_0('done :-)')
def evaluate(self, model, dataloader, example_dict, args):
    model.eval()
    store = torch.distributed.TCPStore(args.master_ip,
                                       18931 + random.randint(0, 10000),
                                       mpu.get_data_parallel_world_size(),
                                       torch.distributed.get_rank() == 0,
                                       datetime.timedelta(seconds=30))
    print_rank_0("Distributed store created")

    with torch.no_grad():
        for idx, data in enumerate(dataloader):
            tokens, attention_mask, position_ids = process_batch(data, args)
            src_tokens = tokens
            batch_size = tokens.size(0)

            # Record the positions of all [MASK] tokens in each source sequence
            # and track which blank is currently being filled.
            mask_positions = []
            current_mask = []
            for text in tokens.tolist():
                mask_positions.append([i for i, x in enumerate(text) if x == self.mask_token])
                current_mask.append(0)
                # print(self.tokenizer.DecodeIds(text))
                # print(mask_positions[-1])

            counter = 0
            done = [False] * batch_size
            while counter < args.tgt_seq_length:
                if counter == 0:
                    # First step: encode the full source once and cache the memories.
                    # print(tokens)
                    # print(position_ids)
                    next_token_logits, *mems = model(tokens, position_ids, attention_mask,
                                                     return_memory=True)
                    next_token_logits = next_token_logits[:, -1]
                    position_ids = tokens.new_ones(batch_size, 2, 1)
                    for i, text in enumerate(tokens.tolist()):
                        mask_pos = mask_positions[i][current_mask[i]]
                        position_ids[i, 0] = mask_pos
                    tokens = tokens.new_zeros(batch_size, 0)
                    attention_mask = tokens.new_zeros(batch_size)
                else:
                    position_ids[:, 1] = position_ids[:, 1] + 1
                    last_token = tokens[:, -1:]
                    next_token_logits, *mems = model(last_token, position_ids, attention_mask,
                                                     *mems, return_memory=True)
                    next_token_logits = next_token_logits[:, -1]
                next_token_scores = F.log_softmax(next_token_logits, dim=-1)
                next_token_scores = self.processors(tokens, next_token_scores)
                next_tokens = next_token_scores.max(dim=-1)[1]
                # print(self.tokenizer.DecodeIds(next_tokens.tolist()))
                for i, next_token in enumerate(next_tokens.tolist()):
                    if next_token == self.end_token:
                        if current_mask[i] + 1 < len(mask_positions[i]):
                            # Move on to the next blank in this sequence.
                            current_mask[i] += 1
                            next_tokens[i] = self.start_token
                            position_ids[i, 0] = mask_positions[i][current_mask[i]]
                            position_ids[i, 1] = 0
                        else:
                            done[i] = True
                    if done[i]:
                        next_tokens[i] = self.pad_token
                if all(done):
                    break
                tokens = torch.cat([tokens, next_tokens.unsqueeze(-1)], dim=-1)
                counter += 1

            # Stitch the generated blanks back into the source sequences.
            predictions = []
            for i, text in enumerate(tokens.tolist()):
                text = [token for token in text
                        if token not in [self.end_token, self.pad_token]]
                blanks = [[]]
                for token in text:
                    if token == self.start_token:
                        blanks.append([])
                    else:
                        blanks[-1].append(token)
                output_tokens = []
                current_blank = 0
                for token in src_tokens[i].tolist():
                    if token == self.mask_token:
                        if current_blank < len(blanks):
                            output_tokens += blanks[current_blank]
                            current_blank += 1
                    else:
                        if token not in [self.pad_token]:
                            output_tokens.append(token)
                text = self.tokenizer.DecodeIds(output_tokens[:-1])
                text = blanklm_fix_tokenization(text)
                predictions.append(text)
                # print(text)

            uid_list = data['uid']
            if isinstance(uid_list, torch.Tensor):
                uid_list = uid_list.cpu().numpy().tolist()
            for uid, prediction in zip(uid_list, predictions):
                store.set(uid, prediction)
            if (idx + 1) % args.log_interval == 0:
                print_rank_0(f"Iteration {idx + 1} / {len(dataloader)}")

    model.train()
    torch.distributed.barrier()
    print_rank_0("Evaluation completed")
    predictions, examples = [], []
    for uid, example in example_dict.items():
        predictions.append(store.get(uid).decode('utf-8'))
        examples.append(example)
    torch.distributed.barrier()
    return predictions, [], examples
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
    # Mix single-task and multi-task batches with a rank-dependent seed.
    rand = random.Random(args.iteration * mpu.get_data_parallel_world_size() +
                         mpu.get_data_parallel_rank())
    if data_iterator[1] and rand.random() < args.multi_task_ratio:
        data = next(data_iterator[1]) if data_iterator[1] else None
        data["mode"] = "multi-task"
    else:
        data = next(data_iterator[0]) if data_iterator[0] else None
    # print_rank_0("data iterator")
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args)
    timers('batch generator').stop()
    # print_rank_0("get batch")

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        sep = attention_mask.item() if torch.numel(attention_mask) == 1 \
            else attention_mask[batch_id].item()
        text, last_segment = "", []
        for i, token_id in enumerate(tokens[batch_id, :sep].tolist()):
            token = tokenizer.IdToToken(token_id)
            if token.startswith('[MASK') or token.endswith('MASK]'):
                if last_segment:
                    text += tokenizer.DecodeIds(last_segment)
                    last_segment = []
                text += f" [{position_ids_[batch_id, i].item()}, {token}]"
            else:
                last_segment.append(token_id)
        if last_segment:
            text += tokenizer.DecodeIds(last_segment)
        print(text.encode('utf-8'))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if tokenizer.IdToToken(tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(tokenizer.DecodeIds(tokens[batch_id, last_index:i].tolist()).encode('utf-8'),
                          "|",
                          tokenizer.DecodeIds(labels[batch_id, last_index:i].tolist()).encode('utf-8'),
                          position_ids_[batch_id, last_index:i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(tokenizer.DecodeIds(tokens[batch_id, last_index:].tolist()).encode('utf-8'),
                  "|",
                  tokenizer.DecodeIds(labels[batch_id, last_index:].tolist()).encode('utf-8'),
                  position_ids_[batch_id, last_index:].tolist(),
                  block_position_ids[batch_id, last_index:].tolist())

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()
    return loss, mems, mode
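# Hedged sketch of the masked-loss reduction at the end of forward_step: only
# positions with loss_mask == 1 contribute, and the sum is normalized by the
# number of unmasked positions. Standalone illustration; the helper name is not
# from the source.
import torch

def masked_mean_loss(losses, loss_mask):
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()
    return loss

# Two unmasked positions with losses 1.0 and 3.0 average to 2.0.
assert masked_mean_loss(torch.tensor([1.0, 3.0]), torch.tensor([1.0, 1.0])).item() == 2.0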
def evaluate(self, model, dataloader, example_dict, args):
    """Calculate correct over total answers and return predictions if
    `output_predictions` is true."""
    model.eval()
    store = torch.distributed.TCPStore(args.master_ip,
                                       18931 + random.randint(0, 10000),
                                       mpu.get_data_parallel_world_size(),
                                       torch.distributed.get_rank() == 0,
                                       datetime.timedelta(seconds=30))
    print_rank_0("Distributed store created")

    with torch.no_grad():
        # For all the batches in the dataset.
        for idx, data in enumerate(dataloader):
            tokens, attention_mask, position_ids = process_batch(data, args)
            batch_size = tokens.size(0)
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                max_length=args.out_seq_length,
                num_beams=args.num_beams,
                device=tokens.device,
                length_penalty=args.length_penalty,
                do_early_stopping=False,
            )
            # Only the first beam of each sample starts with a finite score so
            # that the initial hypotheses are distinct.
            beam_scores = torch.zeros((batch_size, args.num_beams),
                                      dtype=torch.float, device=tokens.device)
            beam_scores[:, 1:] = -1e9
            beam_scores = beam_scores.view((batch_size * args.num_beams,))

            # Run the model forward.
            counter = 0
            while counter < args.tgt_seq_length:
                if counter == 0:
                    # First step: encode the source once, then replicate the
                    # logits and memories for every beam.
                    next_token_logits, *mems = model(tokens, position_ids, attention_mask,
                                                     return_memory=True)
                    seq_length = next_token_logits.size(1)
                    next_token_logits = next_token_logits[:, -1]
                    next_token_logits = next_token_logits.unsqueeze(1).repeat(
                        1, args.num_beams, 1).view(batch_size * args.num_beams, -1)
                    mems = [mem.unsqueeze(1).repeat(1, args.num_beams, 1, 1).view(
                        batch_size * args.num_beams, seq_length, -1) for mem in mems]
                    position_ids = tokens.new_ones(batch_size, args.num_beams, 2, 1)
                    for i, text in enumerate(tokens.tolist()):
                        mask_pos = text.index(self.mask_token)
                        position_ids[i, :, 0] = mask_pos
                    position_ids = position_ids.reshape(batch_size * args.num_beams, 2, 1)
                    tokens = tokens.new_zeros(batch_size * args.num_beams, 0)
                    attention_mask = tokens.new_zeros([batch_size * args.num_beams])
                else:
                    if not args.no_block_position:
                        position_ids[:, 1] = counter + 1
                    last_token = tokens[:, -1:]
                    next_token_logits, *mems = model(last_token, position_ids, attention_mask,
                                                     *mems, return_memory=True)
                    next_token_logits = next_token_logits[:, -1]
                next_token_scores = F.log_softmax(next_token_logits, dim=-1)
                next_token_scores = self.processors(tokens, next_token_scores)
                next_token_scores = next_token_scores + beam_scores[:, None].expand_as(
                    next_token_scores)
                vocab_size = next_token_scores.shape[-1]
                next_token_scores = next_token_scores.view(batch_size,
                                                           args.num_beams * vocab_size)

                probs = F.softmax(next_token_scores, dim=-1)
                if args.select_topk:
                    _, next_tokens = torch.topk(probs, k=2 * args.num_beams,
                                                dim=-1, largest=True)
                else:
                    next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams)
                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
                next_token_scores, _indices = torch.sort(next_token_scores,
                                                         descending=True, dim=1)
                next_tokens = torch.gather(next_tokens, -1, _indices)

                # Recover the beam index and the vocabulary id of each candidate.
                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size
                # stateless
                beam_outputs = beam_scorer.process(tokens, next_token_scores, next_tokens,
                                                   next_indices,
                                                   eos_token_id=self.end_token,
                                                   pad_token_id=self.pad_token)
                beam_scores = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]
                beam_next_tokens = beam_next_tokens.unsqueeze(-1)
                tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
                mems = [mem[beam_idx] for mem in mems] if mems else []
                if beam_scorer.is_done:
                    break
                counter += 1

            tokens, _ = beam_scorer.finalize(tokens, beam_scores, next_tokens,
                                             next_indices,
                                             eos_token_id=self.end_token,
                                             pad_token_id=self.pad_token)
            predictions = []
            for text in tokens.tolist():
                text = [token for token in text
                        if token not in [self.end_token, self.pad_token]]
                text = self.tokenizer.DecodeIds(text)
                predictions.append(text)
            uid_list = data['uid']
            if isinstance(uid_list, torch.Tensor):
                uid_list = uid_list.cpu().numpy().tolist()
            for uid, prediction in zip(uid_list, predictions):
                store.set(uid, prediction)
            if (idx + 1) % args.log_interval == 0:
                print_rank_0(f"Iteration {idx + 1} / {len(dataloader)}")

    model.train()
    torch.distributed.barrier()
    print_rank_0("Evaluation completed")
    predictions, examples = [], []
    for uid, example in example_dict.items():
        predictions.append(store.get(uid).decode('utf-8'))
        examples.append(example)
    torch.distributed.barrier()
    return predictions, [], examples
def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provider, args):
    """Build train, validation, and test data iterators."""

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from rank 0 of each model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.format(
            train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.format(
            valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None
    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None
    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
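# Hedged sketch of the sample-count arithmetic above: training triggers
# train_iters // eval_interval evaluations plus one final one, and each run
# consumes eval_iters batches of global_batch_size samples. Illustrative helper
# name, not from the source.
def train_valid_test_num_samples(train_iters, eval_interval, eval_iters, global_batch_size):
    eval_runs = train_iters // eval_interval + 1
    return [train_iters * global_batch_size,
            eval_runs * eval_iters * global_batch_size,
            eval_iters * global_batch_size]

# 1000 train iters, eval every 100 iters for 10 iters, global batch 32:
# 32000 train, 11 * 10 * 32 = 3520 validation, 320 test samples.
assert train_valid_test_num_samples(1000, 100, 10, 32) == [32000, 3520, 320]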
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Get the tokenizer.
    tokenizer = GPT2Tokenizer(os.path.join(args.tokenizer_path, 'vocab.json'),
                              os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # Load train data.
    if args.do_train:
        train_dataloader, _ = load_data(args, 'train', tokenizer, 1)
        dev_dataloader, dev_dataset = load_data(args, 'dev', tokenizer, 1)
        with open(args.deepspeed_config, "r") as f:
            deepspeed_conf = json.load(f)
        epoch = args.epoch
        grad_acc = deepspeed_conf["gradient_accumulation_steps"]
        args.train_iters = len(train_dataloader) * epoch / grad_acc

        # Model, optimizer, and learning rate.
        # TODO: maybe need to reinitialize optimizer
    elif args.do_eval:
        # Set an arbitrary positive integer since the optimizer and the
        # scheduler will not be used when doing eval.
        args.train_iters = 1

    model, optimizer, lr_scheduler = setup_model_and_optimizer_C(args)
    device = torch.cuda.current_device()

    # Give a timestamp to the results directory.
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir, "{}-{}".format(args.model_name, cur_time))
    os.makedirs(results_dir, exist_ok=True)

    if args.do_train and torch.distributed.get_rank() == 0:
        with open(os.path.join(results_dir, "train_log.txt"), "w") as f:
            f.write("Train losses:\n")
        with open(os.path.join(results_dir, "dev_log.txt"), "w") as f:
            f.write("Dev accs:\n")

    torch.distributed.barrier()

    if args.do_train:
        # cand_ids = torch.tensor(dev_dataset.cand_ids).to(device)
        total_loss, logging_loss, best_acc = 0.0, 0.0, 0.0
        global_step, total_step, best_step = 0, 0, 0
        for e in range(epoch):
            model.train()
            for batch, no_model_batch in tqdm(train_dataloader,
                                              disable=(torch.distributed.get_rank() != 0)):
                for k in batch:
                    batch[k] = batch[k].to(device)
                for k in no_model_batch:
                    no_model_batch[k] = no_model_batch[k].to(device)

                output = model(**batch)
                # Get the loss of the last token.
                output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / \
                    torch.sum(no_model_batch["loss_mask"], -1).unsqueeze(-1)
                # Get the label of the last token.
                # labels = no_model_batch["labels"].float()
                labels = no_model_batch["truth"]
                # labels = (torch.sum(labels * no_model_batch["loss_mask"], 1) / torch.sum(no_model_batch["loss_mask"], -1)).long()

                # Cross-entropy loss. nn.CrossEntropyLoss must be instantiated
                # before it is called, and it expects long class indices as targets.
                # losses = mpu.vocab_parallel_cross_entropy(output.unsqueeze(1).contiguous().float(), labels.unsqueeze(1))
                losses = CrossEntropyLoss()(output.contiguous().float(), labels.long())
                loss = torch.mean(losses)

                model.backward(loss)
                model.step()

                torch.distributed.all_reduce(loss.data, group=mpu.get_data_parallel_group())
                loss.data = loss.data / mpu.get_data_parallel_world_size()
                total_loss += loss.item() / grad_acc

                if total_step % grad_acc == 0:
                    global_step += 1
                    if global_step != 0 and global_step % args.log_interval == 0:
                        # Logging.
                        if torch.distributed.get_rank() == 0:
                            train_log = "Epoch {}, global step {}, total step {}, train lm loss: {}".format(
                                e, global_step, epoch * len(train_dataloader),
                                (total_loss - logging_loss) / args.log_interval)
                            yprint(train_log)
                            with open(os.path.join(results_dir, "train_log.txt"), "a") as f:
                                f.write(train_log + "\n")
                        logging_loss = total_loss

                    if global_step != 0 and global_step % args.eval_interval == 0:
                        # Evaluate on the dev set.
                        acc, _, _ = evaluate_tnews(args, model, dev_dataloader, device, mode="dev")
                        dev_results_dir = os.path.join(results_dir,
                                                       "dev_step-{}".format(global_step))
                        if acc > best_acc:
                            best_acc = acc
                            best_step = global_step
                        if torch.distributed.get_rank() == 0:
                            # We write the log files only once, on rank 0.
                            dev_log = "Epoch: {}, Global step: {}, Acc: {}".format(
                                e, global_step, acc)
                            yprint(dev_log)
                            os.makedirs(dev_results_dir, exist_ok=True)
                            with open(os.path.join(dev_results_dir, "dev_result.txt"), "w") as f:
                                f.write(dev_log + "\n")
                            with open(os.path.join(results_dir, "dev_log.txt"), "a") as f:
                                f.write(dev_log + "\n")
                        torch.distributed.barrier()
                        args.save = dev_results_dir
                        save_checkpoint(global_step, model, optimizer, lr_scheduler, args)
                total_step += 1
        with open(os.path.join(dev_results_dir, "dev_log.txt"), "a") as f:
            f.write("Best acc: {} Best step: {}\n".format(best_acc, best_step))

    if args.do_eval:
        # Evaluate on the test set.
        test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
        cand_ids = torch.tensor(test_dataset.cand_ids).to(device)

        if args.do_train:
            # If training was done, evaluate the checkpoint with the best acc on the dev set.
            eval_ckpt_path = os.path.join(results_dir, "dev_step-{}".format(best_step))
            args.load = eval_ckpt_path
        else:
            # If only doing eval, evaluate the checkpoint specified by the user.
            args.load = args.eval_ckpt_path
        load_checkpoint(model=model, optimizer=None, lr_scheduler=None, args=args)

        acc, _, _ = evaluate(args, model, test_dataloader, cand_ids, device, mode="test")

        if torch.distributed.get_rank() == 0:
            eval_log = "Checkpoint from {}: Acc: {}".format(args.load, acc)
            yprint(eval_log)
            with open(os.path.join(results_dir, "eval_log"), "w") as f:
                f.write(eval_log + "\n")

    torch.distributed.barrier()
def load_tnews_data(data_path, data_type, tokenizer, few_shot=False):
    args = get_args()

    filename = os.path.join(data_path, data_type + '.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    labels = []
    label_map = {}
    label_reverse = {}
    with open(os.path.join(data_path, 'labels.json')) as fin:
        for i, line in enumerate(fin):
            obj = json.loads(line.strip())
            labels.append(obj['label_desc'])
            label_map[obj['label_desc']] = i
            label_reverse[obj['label']] = obj['label_desc']

    all_tokens = []
    all_masks = []
    all_labels = []
    for _, obj in enumerate(objs):
        sentence = obj['sentence']
        tokenized_sentence = tokenizer.encode(sentence)[:args.seq_length - 20]
        obj['label_desc'] = label_reverse[obj['label']]

        if few_shot:
            # Sample three wrong labels, then append the correct one.
            cur_labels = random.sample(labels, 3)
            while obj['label_desc'] in cur_labels:
                cur_labels = random.sample(labels, 3)
            cur_labels.append(obj['label_desc'])
            cur_label = cur_labels.index(obj['label_desc'])
            assert cur_label != -1
        else:
            cur_labels = labels
            cur_label = label_map[obj['label_desc']]
        all_labels.append(cur_label)

        for _, label in enumerate(cur_labels):
            # The prompt reads "This is an article about {label}:".
            prompt = "这是关于{}的文章:".format(label)
            prompt_tokens = tokenizer.encode(prompt)
            prompt_len = len(prompt_tokens)
            tokens = prompt_tokens + tokenized_sentence
            second_mask = [0] * (args.seq_length - 1)
            for idx in range(prompt_len - 1, len(tokens) - 1):
                second_mask[idx] = 1
            all_masks.append(second_mask)
            token_length = len(tokens)
            assert token_length < args.seq_length
            tokens.extend([pad_id] * (args.seq_length - token_length))
            all_tokens.append(tokens)

    all_tokens = torch.tensor(all_tokens, dtype=torch.long)
    all_masks = torch.tensor(all_masks, dtype=torch.float)
    dataset = TensorDataset(all_tokens, all_masks)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True), all_labels
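# Hedged sketch of the mask layout built above: for a prompt of length p followed
# by a sentence, positions p-1 .. len(tokens)-2 of the (seq_length - 1)-long mask
# are set to 1, so the LM loss covers exactly the sentence tokens (each token is
# predicted from the previous position, hence the -1 shift). Illustrative helper
# name, not from the source.
def build_loss_mask(prompt_len, total_len, seq_length):
    mask = [0] * (seq_length - 1)
    for idx in range(prompt_len - 1, total_len - 1):
        mask[idx] = 1
    return mask

# Prompt of 3 tokens, 5 tokens overall, seq_length 8: tokens 3 and 4 are scored.
assert build_loss_mask(3, 5, 8) == [0, 0, 1, 1, 0, 0, 0]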
def load_ocnli_data(data_path, data_type, tokenizer):
    args = get_args()

    filename = os.path.join(data_path, data_type + '.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    all_tokens_1 = []
    all_masks_1 = []
    all_tokens_2 = []
    all_masks_2 = []
    all_tokens_3 = []
    all_masks_3 = []
    all_labels = []
    for obj in objs:
        if obj['label'] == '-':
            continue

        # Candidate 1: "{sentence1}?对," ("right", entailment).
        prompt = "{}?对,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length - 1)
        for idx in range(prompt_len - 1, len(tokens) - 1):
            second_mask[idx] = 1
        all_masks_1.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_1.append(tokens)

        # Candidate 2: "{sentence1}?错," ("wrong", contradiction).
        prompt = "{}?错,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length - 1)
        for idx in range(prompt_len - 1, len(tokens) - 1):
            second_mask[idx] = 1
        all_masks_2.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_2.append(tokens)

        # Candidate 3: "{sentence1}?也许," ("maybe", neutral).
        prompt = "{}?也许,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length - 1)
        for idx in range(prompt_len - 1, len(tokens) - 1):
            second_mask[idx] = 1
        all_masks_3.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_3.append(tokens)

        if obj['label'] == 'entailment':
            all_labels.append([0])
        elif obj['label'] == 'contradiction':
            all_labels.append([1])
        else:
            all_labels.append([2])

    all_tokens_1 = torch.tensor(all_tokens_1, dtype=torch.long)
    all_masks_1 = torch.tensor(all_masks_1, dtype=torch.float)
    all_tokens_2 = torch.tensor(all_tokens_2, dtype=torch.long)
    all_masks_2 = torch.tensor(all_masks_2, dtype=torch.float)
    all_tokens_3 = torch.tensor(all_tokens_3, dtype=torch.long)
    all_masks_3 = torch.tensor(all_masks_3, dtype=torch.float)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    dataset = TensorDataset(all_tokens_1, all_masks_1,
                            all_tokens_2, all_masks_2,
                            all_tokens_3, all_masks_3, all_labels)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a random sampler with distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
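# Hedged sketch: the three token tensors built above correspond to the three
# verbalized NLI prompts, and the stored label indexes into that same order, so
# the argmin over (output_1, output_2, output_3) in evaluate_ocnli lines up with
# the labels. The constants and helper below are illustrative, not from the source.
OCNLI_PROMPTS = ["{}?对,", "{}?错,", "{}?也许,"]  # entailment, contradiction, neutral
OCNLI_LABELS = {"entailment": 0, "contradiction": 1, "neutral": 2}

def verbalize(sentence1, label):
    return OCNLI_PROMPTS[OCNLI_LABELS[label]].format(sentence1)

assert verbalize("今天下雨", "contradiction") == "今天下雨?错,"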