import os
import random

import torch.distributed as dist
from tqdm import tqdm

# Project-local helpers used in this section (PreTrainingDataset, PretrainDataType,
# PretrainBatch, QADataset, QAFinetuningDataset, RankingDataset, get_dataloader,
# get_effective_batch, AsyncWorker, ...) are assumed to be imported or defined
# elsewhere in this module.


def get_train_dataset(args, index, finetune=False, shuffle=True):
    assert not finetune, "finetune not supported"
    i = 0
    dataloaders = {}
    datalengths = []
    batches_per_dataset = []
    batch_mapping = {}

    config = args.config
    dataset_paths = config["data"]["datasets"]
    dataset_flags = config["data"]["flags"]

    # Pretraining dataset
    if dataset_flags.get("pretrain_dataset", False):
        pretrain_type = dataset_flags.get("pretrain_type")
        if pretrain_type == "wiki_bc":
            # Load Wikipedia dataset
            wiki_pretrain_dataset = PreTrainingDataset(
                args.tokenizer,
                os.path.join(args.data_path_prefix,
                             dataset_paths['wiki_pretrain_dataset']),
                args.logger, args.max_seq_length, index,
                PretrainDataType.NUMPY, args.max_predictions_per_seq)
            datalengths.append(len(wiki_pretrain_dataset))
            dataloaders[i] = get_dataloader(args, wiki_pretrain_dataset)
            batch_mapping[i] = PretrainBatch
            batches_per_dataset.append(
                get_effective_batch(args, len(wiki_pretrain_dataset)))
            i += 1

            # Load BookCorpus dataset
            bc_pretrain_dataset = PreTrainingDataset(
                args.tokenizer,
                os.path.join(args.data_path_prefix,
                             dataset_paths['bc_pretrain_dataset']),
                args.logger, args.max_seq_length, index,
                PretrainDataType.NUMPY, args.max_predictions_per_seq)
            datalengths.append(len(bc_pretrain_dataset))
            dataloaders[i] = get_dataloader(args, bc_pretrain_dataset)
            batch_mapping[i] = PretrainBatch
            batches_per_dataset.append(
                get_effective_batch(args, len(bc_pretrain_dataset)))
            i += 1

    # One entry per effective batch, tagged with the index of the dataset the
    # batch should be drawn from.
    dataset_batches = []
    for i, batch_count in enumerate(batches_per_dataset):
        dataset_batches.extend([i] * batch_count)

    # Shuffle so batches from different datasets are interleaved.
    if shuffle:
        random.shuffle(dataset_batches)

    # Repeat each entry to cover gradient accumulation and bucket refreshes.
    dataset_picker = []
    for dataset_batch_type in dataset_batches:
        dataset_picker.extend([dataset_batch_type] *
                              args.gradient_accumulation_steps *
                              args.refresh_bucket_size)

    return dataset_picker, dataloaders, sum(datalengths)
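# A minimal, hypothetical sketch of how the picker and dataloaders returned by
# get_train_dataset might be consumed. Each entry in dataset_picker names the
# dataset to draw the next micro-batch from, so batches from different corpora
# are interleaved while consecutive picks within an accumulation window stay on
# the same dataset. The names `model`, `optimizer`, and the `model.network(batch)`
# call signature are assumptions for illustration, not definitions from this module.
def example_train_epoch(args, index, model, optimizer):
    dataset_picker, dataloaders, total_samples = get_train_dataset(args, index)
    # Keep one live iterator per dataset so picks can alternate between them.
    dataset_iters = {k: iter(v) for k, v in dataloaders.items()}
    for step, dataset_id in enumerate(dataset_picker):
        batch = next(dataset_iters[dataset_id])          # draw from the picked dataset
        batch = tuple(t.to(args.device) for t in batch)  # move tensors to the device
        loss = model.network(batch)                      # assumed forward pass returning a loss
        loss.backward()
        if (step + 1) % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()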
def pretrain_validation(args, index, model):
    config = args.config
    logger = args.logger

    model.eval()
    dataset = PreTrainingDataset(
        args.tokenizer,
        os.path.join(args.data_path_prefix, config['validation']['path']),
        args.logger, args.max_seq_length, index,
        PretrainDataType.VALIDATION, args.max_predictions_per_seq)
    data_batches = get_dataloader(args, dataset, eval_set=True)

    eval_loss = 0
    nb_eval_steps = 0
    for batch in tqdm(data_batches):
        batch = tuple(t.to(args.device) for t in batch)
        tmp_eval_loss = model.network(batch, log=False)
        # Sum the per-rank loss onto rank 0, then divide by the world size to
        # turn the sum into an average across GPUs.
        dist.reduce(tmp_eval_loss, 0)
        tmp_eval_loss = tmp_eval_loss / dist.get_world_size()
        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    logger.info(f"Validation Loss for epoch {index + 1} is: {eval_loss}")

    # Only one process writes the TensorBoard scalar: rank 0 in the distributed
    # GPU case, or the single CPU process otherwise.
    if (not args.no_cuda and dist.get_rank() == 0) or (
            args.no_cuda and args.local_rank == -1):
        args.summary_writer.add_scalar('Validation/Loss', eval_loss, index + 1)
    return
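# The rank-0 guard above writes to args.summary_writer. A plausible way that
# writer could be created (an assumption; the real setup lives outside this
# section) is with torch.utils.tensorboard, on a single logging process only:
def example_setup_summary_writer(args):
    from torch.utils.tensorboard import SummaryWriter

    # Mirror the guard used in pretrain_validation: only one process logs.
    if (not args.no_cuda and dist.get_rank() == 0) or (
            args.no_cuda and args.local_rank == -1):
        # `args.output_dir` is a hypothetical attribute, used for illustration only.
        args.summary_writer = SummaryWriter(log_dir=args.output_dir)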
def get_shard(self, index, shuffle=True):
    datalengths = []
    batches_per_dataset = []

    for i, dataset_path in enumerate(self.dataset_paths):
        pretrain_dataset = PreTrainingDataset(
            tokenizer=self.tokenizer,
            folder=dataset_path,
            logger=self.logger,
            max_seq_length=self.max_seq_length,
            index=index,
            data_type=PretrainDataType.NUMPY,
            max_predictions_per_seq=self.max_predictions_per_seq)
        datalengths.append(len(pretrain_dataset))
        batches_per_dataset.append(
            self._get_effective_batch(len(pretrain_dataset)))
        self.dataloaders[i] = self._get_dataloader(pretrain_dataset)

    dataset_batches = []
    for i, batch_count in enumerate(batches_per_dataset):
        dataset_batches.extend([i] * batch_count)

    # shuffle
    if shuffle:
        random.shuffle(dataset_batches)

    self.dataset_iterator = []
    for dataset_batch_type in dataset_batches:
        self.dataset_iterator.extend(
            [dataset_batch_type] * self.gradient_accumulation_steps *
            self.refresh_bucket_size)

    if self.async_dataloading:
        self.async_worker = AsyncWorker(self.dataloaders,
                                        self.dataset_iterator)
        self.async_worker.start()

    return self.dataset_iterator, sum(datalengths)
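# get_shard reads several attributes off `self` (dataset_paths, tokenizer,
# gradient_accumulation_steps, refresh_bucket_size, async_dataloading, the
# dataloaders dict, ...). Below is a hedged sketch of a constructor that would
# provide them; the class name and parameter order are assumptions, only the
# attribute names themselves come from get_shard above.
class ExampleShardedDataProvider:
    def __init__(self, tokenizer, dataset_paths, logger, max_seq_length,
                 max_predictions_per_seq, gradient_accumulation_steps,
                 refresh_bucket_size, async_dataloading=False):
        self.tokenizer = tokenizer
        self.dataset_paths = dataset_paths            # one folder per pretraining corpus
        self.logger = logger
        self.max_seq_length = max_seq_length
        self.max_predictions_per_seq = max_predictions_per_seq
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.refresh_bucket_size = refresh_bucket_size
        self.async_dataloading = async_dataloading
        self.dataloaders = {}                          # filled per shard by get_shard
        self.dataset_iterator = []
        self.async_worker = None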
def get_train_dataset(args, index, finetune=False, shuffle=True):
    i = 0
    dataloaders = {}
    datalengths = []
    batches_per_dataset = []
    batch_mapping = {}

    config = args.config
    dataset_paths = config["data"]["datasets"]
    dataset_flags = config["data"]["flags"]

    if finetune:
        qp_finetune_dataset = QAFinetuningDataset(
            args.tokenizer, dataset_paths["qp_finetuning_dataset"],
            args.logger, args.max_seq_length)
        datalengths.append(len(qp_finetune_dataset))
        dataloaders[i] = get_dataloader(args, qp_finetune_dataset)
        batch_mapping[i] = QABatch
        batches_per_dataset.append(
            get_effective_batch(args, len(qp_finetune_dataset)))
        i += 1
    else:
        # QP dataset
        if dataset_flags.get("qp_dataset", False):
            qp_dataset = QADataset(args.tokenizer,
                                   dataset_paths["qp_dataset"], args.logger,
                                   args.max_seq_length, index)
            datalengths.append(len(qp_dataset))
            dataloaders[i] = get_dataloader(args, qp_dataset)
            batch_mapping[i] = QABatch
            batches_per_dataset.append(
                get_effective_batch(args, len(qp_dataset)))
            i += 1

        # Pretraining dataset
        if dataset_flags.get("pretrain_dataset", False):
            pretrain_type = dataset_flags.get("pretrain_type")

            # Clean-body data load
            if pretrain_type == "clean_body":
                cb_pretrain_dataset = PreTrainingDataset(
                    args.tokenizer, dataset_paths['cb_pretrain_dataset'],
                    args.logger, args.max_seq_length, index,
                    PretrainDataType.NUMPY)
                datalengths.append(len(cb_pretrain_dataset))
                dataloaders[i] = get_dataloader(args, cb_pretrain_dataset)
                batch_mapping[i] = PretrainBatch
                batches_per_dataset.append(
                    get_effective_batch(args, len(cb_pretrain_dataset)))
                i += 1
            elif pretrain_type == "wiki_bc":
                # Load Wikipedia dataset
                wiki_pretrain_dataset = PreTrainingDataset(
                    args.tokenizer, dataset_paths['wiki_pretrain_dataset'],
                    args.logger, args.max_seq_length, index,
                    PretrainDataType.NUMPY, args.max_predictions_per_seq)
                datalengths.append(len(wiki_pretrain_dataset))
                dataloaders[i] = get_dataloader(args, wiki_pretrain_dataset)
                batch_mapping[i] = PretrainBatch
                batches_per_dataset.append(
                    get_effective_batch(args, len(wiki_pretrain_dataset)))
                i += 1

                # Load BookCorpus dataset
                bc_pretrain_dataset = PreTrainingDataset(
                    args.tokenizer, dataset_paths['bc_pretrain_dataset'],
                    args.logger, args.max_seq_length, index,
                    PretrainDataType.NUMPY, args.max_predictions_per_seq)
                datalengths.append(len(bc_pretrain_dataset))
                dataloaders[i] = get_dataloader(args, bc_pretrain_dataset)
                batch_mapping[i] = PretrainBatch
                batches_per_dataset.append(
                    get_effective_batch(args, len(bc_pretrain_dataset)))
                i += 1

        # Ranking dataset
        if dataset_flags.get("ranking_dataset", False):
            ranking_dataset = RankingDataset(args.tokenizer,
                                             dataset_paths['ranking_dataset'],
                                             args.logger, args.max_seq_length,
                                             index, args.fp16)
            datalengths.append(len(ranking_dataset))
            dataloaders[i] = get_dataloader(args, ranking_dataset)
            batch_mapping[i] = RankingBatch
            batches_per_dataset.append(
                get_effective_batch(args, len(ranking_dataset)))
            i += 1

    # One entry per effective batch, tagged with the index of the dataset the
    # batch should be drawn from.
    dataset_batches = []
    for i, batch_count in enumerate(batches_per_dataset):
        dataset_batches.extend([i] * batch_count)

    # Shuffle so batches from different datasets are interleaved.
    if shuffle:
        random.shuffle(dataset_batches)

    # Repeat each entry to cover gradient accumulation and bucket refreshes.
    dataset_picker = []
    for dataset_batch_type in dataset_batches:
        dataset_picker.extend([dataset_batch_type] *
                              args.gradient_accumulation_steps *
                              args.refresh_bucket_size)

    return dataset_picker, dataloaders, sum(datalengths)
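# The config consumed by get_train_dataset (and by pretrain_validation above)
# is assumed to have the shape below. The keys are taken directly from the
# lookups in this section; the path values are placeholders, not real paths.
EXAMPLE_CONFIG = {
    "data": {
        "flags": {
            "qp_dataset": True,
            "pretrain_dataset": True,
            "pretrain_type": "wiki_bc",   # or "clean_body"
            "ranking_dataset": False,
        },
        "datasets": {
            "qp_finetuning_dataset": "path/to/qp_finetune",
            "qp_dataset": "path/to/qp",
            "cb_pretrain_dataset": "path/to/clean_body",
            "wiki_pretrain_dataset": "path/to/wiki",
            "bc_pretrain_dataset": "path/to/bookcorpus",
            "ranking_dataset": "path/to/ranking",
        },
    },
    "validation": {
        "path": "path/to/validation",
    },
}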