Code example #1
import util  # assumed project helper module (read_squad, merge, QADomainDataset)

def get_dataset(args, datasets, data_dir, tokenizer, split_name):
    # Binary setup: every dataset is read with the same domain label (1).
    # read_and_process is assumed to be defined elsewhere in this file.
    datasets = datasets.split(',')
    dataset_dict = None
    dataset_name = 'binary'
    for dataset in datasets:
        # Build a composite name such as 'binary_squad_newsqa' for read_and_process.
        dataset_name += f'_{dataset}'
        dataset_dict_curr = util.read_squad(f'{data_dir}/{dataset}', label=1)
        dataset_dict = util.merge(dataset_dict, dataset_dict_curr)
    data_encodings = read_and_process(args, tokenizer, dataset_dict, data_dir,
                                      dataset_name, split_name)
    return util.QADomainDataset(data_encodings,
                                train=(split_name == 'train')), dataset_dict
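Each example folds per-dataset dicts together with util.merge. As a rough sketch of what that helper might do, assuming each dataset dict maps field names (question, context, id, answer) to parallel lists of examples; the project's actual implementation may differ:

# Hypothetical sketch of util.merge, assuming dataset dicts are plain
# dicts of parallel lists. Not the project's actual code.
def merge(base, extra):
    if base is None:  # first dataset seen in the loop above
        return extra
    for key, values in extra.items():
        base[key] = base[key] + values  # concatenate this dataset's examples
    return base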
Code example #2
def get_dataset(args, datasets, data_dir, tokenizer, split_name):
    # Individual setup: each dataset gets its own domain label.
    # Validation splits start at label 3, training splits at 0.
    datasets = datasets.split(',')
    dataset_dict = None
    dataset_name = 'individual'
    label = 3 if 'val' in split_name else 0
    for dataset in datasets:
        dataset_name += f'_{dataset}'
        dataset_dict_curr = util.read_squad(f'{data_dir}/{dataset}',
                                            label=label)
        dataset_dict = util.merge(dataset_dict, dataset_dict_curr)
        label += 1  # the next dataset gets the next domain label
    data_encodings = read_and_process(args, tokenizer, dataset_dict, data_dir,
                                      dataset_name, split_name)
    return util.QADomainDataset(data_encodings,
                                train=(split_name == 'train')), dataset_dict
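For reference, a hypothetical invocation; the dataset names and directory below are placeholders, and args/tokenizer are assumed to come from the surrounding training script:

# Hypothetical usage; names and paths are illustrative only.
train_set, train_dict = get_dataset(
    args,
    datasets='squad,nat_questions,newsqa',  # comma-separated, as split(',') expects
    data_dir='datasets/indomain_train',
    tokenizer=tokenizer,
    split_name='train')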
Code example #3
import util  # assumed project helper module, as above
import xuran_perform_eda  # assumed module providing the EDA augmentation

def get_train_dataset(args,
                      target_data_dir,
                      target_dataset,
                      tokenizer,
                      split_name,
                      source_data_dir=None,
                      source_dataset=None):
    dataset_dict_source = None
    dataset_dict_target = None
    data_encodings_source = None
    source_dataset_name = 'individual'
    target_dataset_name = 'individual'
    if source_data_dir is not None and source_dataset is not None:
        # Source (in-domain) datasets are labeled 0, 1, 2, ...
        datasets = source_dataset.split(',')
        label = 0
        for dataset in datasets:
            source_dataset_name += f'_{dataset}'
            dataset_dict_curr = util.read_squad(f'{source_data_dir}/{dataset}',
                                                label=label)
            dataset_dict_source = util.merge(dataset_dict_source,
                                             dataset_dict_curr)
            label += 1
        data_encodings_source = read_and_process(args, tokenizer,
                                                 dataset_dict_source,
                                                 source_data_dir,
                                                 source_dataset_name,
                                                 split_name)
    # Target (out-of-domain) datasets are labeled from 3 upward so their
    # domain labels never collide with the source labels above.
    label = 3
    datasets = target_dataset.split(',')
    for dataset in datasets:
        target_dataset_name += f'_{dataset}'  # accumulate, as in the source loop
        # Read the target data through EDA (easy data augmentation) rather
        # than plain util.read_squad.
        dataset_dict_curr = xuran_perform_eda.perform_eda(
            f'{target_data_dir}/{dataset}',
            dataset,
            train_fraction=1,
            label=label)
        dataset_dict_target = util.merge(dataset_dict_target,
                                         dataset_dict_curr)
        label += 1
    data_encodings_target = read_and_process(args, tokenizer,
                                             dataset_dict_target,
                                             target_data_dir,
                                             target_dataset_name, split_name)
    # Combine source and target examples into a single training set.
    dataset_dict = util.merge(dataset_dict_source, dataset_dict_target)
    data_encodings = util.merge(data_encodings_source, data_encodings_target)
    return util.QADomainDataset(data_encodings,
                                train=(split_name == 'train')), dataset_dict
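And a hypothetical call to the combined loader; all paths and dataset names below are placeholders chosen only to illustrate the source/target split:

# Hypothetical usage; args and tokenizer come from the surrounding script.
train_set, train_dict = get_train_dataset(
    args,
    target_data_dir='datasets/oodomain_train',  # placeholder paths
    target_dataset='race,relation_extraction,duorc',  # labeled 3, 4, 5
    tokenizer=tokenizer,
    split_name='train',
    source_data_dir='datasets/indomain_train',
    source_dataset='squad,nat_questions,newsqa')  # labeled 0, 1, 2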