Example #1
def get_train_dataset(args, index, finetune=False, shuffle=True):
    assert not finetune, "finetune not supported"
    i = 0
    dataloaders = {}
    datalengths = []
    batchs_per_dataset = []
    batch_mapping = {}

    config = args.config
    dataset_paths = config["data"]["datasets"]
    dataset_flags = config["data"]["flags"]

    # Pretraining dataset
    if dataset_flags.get("pretrain_dataset", False):
        pretrain_type = dataset_flags.get("pretrain_type")

        if pretrain_type == "wiki_bc":
            # Load Wiki Dataset
            wiki_pretrain_dataset = PreTrainingDataset(
                args.tokenizer,
                os.path.join(args.data_path_prefix,
                             dataset_paths['wiki_pretrain_dataset']),
                args.logger, args.max_seq_length, index,
                PretrainDataType.NUMPY, args.max_predictions_per_seq)
            datalengths.append(len(wiki_pretrain_dataset))
            dataloaders[i] = get_dataloader(args, wiki_pretrain_dataset)
            batch_mapping[i] = PretrainBatch
            batchs_per_dataset.append(
                get_effective_batch(args, len(wiki_pretrain_dataset)))
            i += 1

            bc_pretrain_dataset = PreTrainingDataset(
                args.tokenizer,
                os.path.join(args.data_path_prefix,
                             dataset_paths['bc_pretrain_dataset']),
                args.logger, args.max_seq_length, index,
                PretrainDataType.NUMPY, args.max_predictions_per_seq)
            datalengths.append(len(bc_pretrain_dataset))
            dataloaders[i] = get_dataloader(args, bc_pretrain_dataset)
            batch_mapping[i] = PretrainBatch
            batchs_per_dataset.append(
                get_effective_batch(args, len(bc_pretrain_dataset)))
            i += 1

    dataset_batches = []
    for i, batch_count in enumerate(batchs_per_dataset):
        dataset_batches.extend([i] * batch_count)

    # shuffle
    if shuffle:
        random.shuffle(dataset_batches)

    # Repeat each effective-batch index so that gradient_accumulation_steps *
    # refresh_bucket_size consecutive micro-batches come from the same dataset.
    dataset_picker = []
    for dataset_batch_type in dataset_batches:
        dataset_picker.extend([dataset_batch_type] *
                              args.gradient_accumulation_steps *
                              args.refresh_bucket_size)

    return dataset_picker, dataloaders, sum(datalengths)
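
A minimal consumption sketch for the values returned above: plain iterators stand in for the real dataloaders, the batch counts are made up, and the training step is a placeholder (none of these names come from the original script).

import random

# Hypothetical stand-ins: two "dataloaders" and the number of effective
# batches each dataset contributes.
dataloaders = {0: iter(range(100)), 1: iter(range(100, 200))}
dataset_batches = [0, 0, 1, 1]
gradient_accumulation_steps = 2
refresh_bucket_size = 1

random.shuffle(dataset_batches)
dataset_picker = []
for dataset_batch_type in dataset_batches:
    dataset_picker.extend([dataset_batch_type] *
                          gradient_accumulation_steps * refresh_bucket_size)

# One micro-batch is drawn per picker entry, always from the dataset it names;
# an optimizer step would follow every gradient_accumulation_steps entries.
for dataset_type in dataset_picker:
    batch = next(dataloaders[dataset_type])
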
Example #2
def pretrain_validation(args, index, model):
    config = args.config
    logger = args.logger

    model.eval()
    dataset = PreTrainingDataset(
        args.tokenizer,
        os.path.join(args.data_path_prefix, config['validation']['path']),
        args.logger, args.max_seq_length, index, PretrainDataType.VALIDATION,
        args.max_predictions_per_seq)
    data_batches = get_dataloader(args, dataset, eval_set=True)
    eval_loss = 0
    nb_eval_steps = 0
    for batch in tqdm(data_batches):
        batch = tuple(t.to(args.device) for t in batch)
        tmp_eval_loss = model.network(batch, log=False)
        # Reduce the loss across all GPUs and average by world size
        dist.reduce(tmp_eval_loss, 0)
        tmp_eval_loss = tmp_eval_loss / dist.get_world_size()
        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
    eval_loss = eval_loss / nb_eval_steps
    logger.info(f"Validation Loss for epoch {index + 1} is: {eval_loss}")
    if (not args.no_cuda
            and dist.get_rank() == 0) or (args.no_cuda
                                          and args.local_rank == -1):
        args.summary_writer.add_scalar('Validation/Loss', eval_loss,
                                       index + 1)
    return
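
The reduce-then-average pattern in pretrain_validation can be exercised on its own with a single-process gloo group; the snippet below is only an illustration with made-up per-batch losses, not part of the original script.

import os
import torch
import torch.distributed as dist

# Single-process process group, just to make the collective calls runnable.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

eval_loss, nb_eval_steps = 0.0, 0
for fake_batch_loss in (2.4, 2.1, 1.9):        # stand-in per-batch losses
    tmp_eval_loss = torch.tensor(fake_batch_loss)
    dist.reduce(tmp_eval_loss, 0)              # sum the loss onto rank 0
    tmp_eval_loss = tmp_eval_loss / dist.get_world_size()
    eval_loss += tmp_eval_loss.mean().item()
    nb_eval_steps += 1

print(f"validation loss: {eval_loss / nb_eval_steps:.4f}")
dist.destroy_process_group()
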
Example #3
    def get_shard(self, index, shuffle=True):
        datalengths = []
        batches_per_dataset = []

        for i, dataset_path in enumerate(self.dataset_paths):
            pretrain_dataset = PreTrainingDataset(
                tokenizer=self.tokenizer,
                folder=dataset_path,
                logger=self.logger,
                max_seq_length=self.max_seq_length,
                index=index,
                data_type=PretrainDataType.NUMPY,
                max_predictions_per_seq=self.max_predictions_per_seq)

            datalengths.append(len(pretrain_dataset))
            batches_per_dataset.append(
                self._get_effective_batch(len(pretrain_dataset)))
            self.dataloaders[i] = self._get_dataloader(pretrain_dataset)

        dataset_batches = []
        for i, batch_count in enumerate(batches_per_dataset):
            dataset_batches.extend([i] * batch_count)

        # shuffle
        if shuffle:
            random.shuffle(dataset_batches)

        self.dataset_iterator = []
        for dataset_batch_type in dataset_batches:
            self.dataset_iterator.extend([dataset_batch_type] *
                                         self.gradient_accumulation_steps *
                                         self.refresh_bucket_size)

        if self.async_dataloading:
            self.async_worker = AsyncWorker(self.dataloaders,
                                            self.dataset_iterator)
            self.async_worker.start()

        return self.dataset_iterator, sum(datalengths)
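
AsyncWorker is project-specific and not shown in these examples; the sketch below only illustrates the underlying idea of prefetching batches on a background thread in the order given by dataset_iterator, with plain lists standing in for the dataloaders (class name and details are assumptions).

import queue
import threading

class PrefetchWorker(threading.Thread):
    """Hypothetical stand-in for AsyncWorker: pre-loads upcoming batches."""

    def __init__(self, dataloaders, dataset_iterator, max_prefetch=2):
        super().__init__(daemon=True)
        self.iters = {k: iter(v) for k, v in dataloaders.items()}
        self.dataset_iterator = dataset_iterator
        self.queue = queue.Queue(maxsize=max_prefetch)

    def run(self):
        for dataset_type in self.dataset_iterator:
            self.queue.put(next(self.iters[dataset_type]))
        self.queue.put(None)                    # sentinel: no more batches

    def get(self):
        return self.queue.get()

# Usage with dummy data: two "dataloaders" and a fixed pick order.
worker = PrefetchWorker({0: [10, 11], 1: [20, 21]}, [0, 1, 0, 1])
worker.start()
while (batch := worker.get()) is not None:
    pass                                        # training step would go here
worker.join()
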
Example #4
def get_train_dataset(args, index, finetune=False, shuffle=True):
    i = 0
    dataloaders = {}
    datalengths = []
    batchs_per_dataset = []
    batch_mapping = {}

    config = args.config
    dataset_paths = config["data"]["datasets"]
    dataset_flags = config["data"]["flags"]

    # Fine-tuning uses a single QA dataset; otherwise build the pre-training mix
    if finetune:
        qp_finetune_dataset = QAFinetuningDataset(
            args.tokenizer, dataset_paths["qp_finetuning_dataset"],
            args.logger, args.max_seq_length)
        datalengths.append(len(qp_finetune_dataset))
        dataloaders[i] = get_dataloader(args, qp_finetune_dataset)
        batch_mapping[i] = QABatch
        batchs_per_dataset.append(
            get_effective_batch(args, len(qp_finetune_dataset)))
        i += 1

    else:
        # QP dataset
        if dataset_flags.get("qp_dataset", False):
            qp_dataset = QADataset(args.tokenizer, dataset_paths["qp_dataset"],
                                   args.logger, args.max_seq_length, index)
            datalengths.append(len(qp_dataset))
            dataloaders[i] = get_dataloader(args, qp_dataset)
            batch_mapping[i] = QABatch
            batchs_per_dataset.append(
                get_effective_batch(args, len(qp_dataset)))
            i += 1

        # Pretraining dataset
        if dataset_flags.get("pretrain_dataset", False):
            pretrain_type = dataset_flags.get("pretrain_type")

            # CLEAN BODY Data Load
            if pretrain_type == "clean_body":
                cb_pretrain_dataset = PreTrainingDataset(
                    args.tokenizer, dataset_paths['cb_pretrain_dataset'],
                    args.logger, args.max_seq_length, index,
                    PretrainDataType.NUMPY)
                datalengths.append(len(cb_pretrain_dataset))
                dataloaders[i] = get_dataloader(args, cb_pretrain_dataset)
                batch_mapping[i] = PretrainBatch
                batchs_per_dataset.append(
                    get_effective_batch(args, len(cb_pretrain_dataset)))
                i += 1

            elif pretrain_type == "wiki_bc":
                # Load Wiki Dataset
                wiki_pretrain_dataset = PreTrainingDataset(
                    args.tokenizer, dataset_paths['wiki_pretrain_dataset'],
                    args.logger, args.max_seq_length, index,
                    PretrainDataType.NUMPY, args.max_predictions_per_seq)
                datalengths.append(len(wiki_pretrain_dataset))
                dataloaders[i] = get_dataloader(args, wiki_pretrain_dataset)
                batch_mapping[i] = PretrainBatch
                batchs_per_dataset.append(
                    get_effective_batch(args, len(wiki_pretrain_dataset)))
                i += 1

                bc_pretrain_dataset = PreTrainingDataset(
                    args.tokenizer, dataset_paths['bc_pretrain_dataset'],
                    args.logger, args.max_seq_length, index,
                    PretrainDataType.NUMPY, args.max_predictions_per_seq)
                datalengths.append(len(bc_pretrain_dataset))
                dataloaders[i] = get_dataloader(args, bc_pretrain_dataset)
                batch_mapping[i] = PretrainBatch
                batchs_per_dataset.append(
                    get_effective_batch(args, len(bc_pretrain_dataset)))
                i += 1

        # Ranking Dataset
        if dataset_flags.get("ranking_dataset", False):
            ranking_dataset = RankingDataset(args.tokenizer,
                                             dataset_paths['ranking_dataset'],
                                             args.logger, args.max_seq_length,
                                             index, args.fp16)
            datalengths.append(len(ranking_dataset))
            dataloaders[i] = get_dataloader(args, ranking_dataset)
            batch_mapping[i] = RankingBatch
            batchs_per_dataset.append(
                get_effective_batch(args, len(ranking_dataset)))
            i += 1

    dataset_batches = []
    for i, batch_count in enumerate(batchs_per_dataset):
        dataset_batches.extend([i] * batch_count)

    # shuffle
    if shuffle:
        random.shuffle(dataset_batches)

    dataset_picker = []
    for dataset_batch_type in dataset_batches:
        dataset_picker.extend([dataset_batch_type] *
                              args.gradient_accumulation_steps *
                              args.refresh_bucket_size)

    return dataset_picker, dataloaders, sum(datalengths)
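
get_effective_batch is not defined in any of these examples; a plausible reimplementation, consistent with the way the picker above expands each entry into gradient_accumulation_steps * refresh_bucket_size micro-batches across all workers, might look like the following (name, signature, and formula are assumptions, not the original helper).

def get_effective_batch_guess(total_samples, micro_batch_size_per_gpu,
                              world_size, gradient_accumulation_steps,
                              refresh_bucket_size):
    # Number of picker entries a dataset of total_samples contributes, assuming
    # each entry later expands into gradient_accumulation_steps *
    # refresh_bucket_size micro-batches spread over world_size workers.
    samples_per_entry = (micro_batch_size_per_gpu * world_size *
                         gradient_accumulation_steps * refresh_bucket_size)
    return total_samples // samples_per_entry

# Example: 1,000,000 samples, micro-batch 32, 8 GPUs, 4 accumulation steps,
# refresh bucket size 1  ->  976 picker entries.
print(get_effective_batch_guess(1_000_000, 32, 8, 4, 1))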