示例#1
0
def load_snli(path, files):
    """Load SNLI split files, build vocabularies, and convert every split to indices.

    The word and label vocabularies are fitted on the FIRST file only
    (presumably the training split). Every file is then indexed with them,
    truncated to ``MAX_LEN`` and marked with fastNLP input/target fields.

    Args:
        path: directory that each entry of ``files`` is joined onto.
        files: split file names relative to ``path``. ``files[0]`` is the
            split the vocabularies are built from.

    Returns:
        ``(ds_list, word_v, tag_v)``: the processed datasets in the same
        order as ``files``, the word vocabulary and the label vocabulary.

    Raises:
        IndexError: if ``files`` is empty (there is no split to fit the
            vocabularies on).
    """
    loader = SNLILoader()
    ds_list = [loader.load(os.path.join(path, f)) for f in files]
    # min_freq=2: presumably words seen only once are kept out of the
    # vocabulary (fastNLP ``Vocabulary`` semantics -- confirm).
    word_v = Vocabulary(min_freq=2)
    # The label vocabulary is created without unknown/padding entries.
    tag_v = Vocabulary(unknown=None, padding=None)

    # Lower-case both sentences of every split before any vocabulary work.
    for ds in ds_list:
        ds.apply(lambda x: [w.lower() for w in x['words1']],
                 new_field_name='words1')
        ds.apply(lambda x: [w.lower() for w in x['words2']],
                 new_field_name='words2')

    # Fit the vocabularies on the first split only.
    update_v(word_v, ds_list[0], 'words1')
    update_v(word_v, ds_list[0], 'words2')
    # apply() is run here for its side effect on tag_v; with
    # new_field_name=None presumably no field is written back.
    ds_list[0].apply(lambda x: tag_v.add_word(x['target']),
                     new_field_name=None)

    def process_data(ds):
        """Index, truncate and mark input/target fields on one split (returns nothing)."""
        to_index(word_v, ds, 'words1', C.INPUTS(0))
        to_index(word_v, ds, 'words2', C.INPUTS(1))
        ds.apply(lambda x: tag_v.to_index(x['target']),
                 new_field_name=C.TARGET)
        # Truncate before measuring so the length fields describe the
        # (possibly shortened) sequences actually fed to the model.
        ds.apply(lambda x: x[C.INPUTS(0)][:MAX_LEN],
                 new_field_name=C.INPUTS(0))
        ds.apply(lambda x: x[C.INPUTS(1)][:MAX_LEN],
                 new_field_name=C.INPUTS(1))
        ds.apply(lambda x: len(x[C.INPUTS(0)]), new_field_name=C.INPUT_LENS(0))
        ds.apply(lambda x: len(x[C.INPUTS(1)]), new_field_name=C.INPUT_LENS(1))
        ds.set_input(C.INPUTS(0), C.INPUTS(1), C.INPUT_LENS(0),
                     C.INPUT_LENS(1))
        ds.set_target(C.TARGET)

    # Iterate the datasets directly instead of indexing with range(len(...)).
    for ds in ds_list:
        process_data(ds)
    return ds_list, word_v, tag_v
示例#2
0
    def test_import(self):
        """Process the sample SNLI file with and without ``extra_split``.

        Each run must produce exactly one split, ``'train'``, holding three
        instances.
        """
        # The bare package import is kept: in a test named test_import it is
        # presumably part of what is being checked.
        import fastNLP
        from fastNLP.io import SNLILoader

        sample = 'test/data_for_tests/sample_snli.jsonl'
        shared = dict(to_lower=True, get_index=True, seq_len_type='seq_len')
        variants = ({'extra_split': ['-']}, {})

        for extra in variants:
            bundle = SNLILoader().process(sample, **shared, **extra)
            assert 'train' in bundle.datasets
            assert len(bundle.datasets) == 1
            assert len(bundle.datasets['train']) == 3
示例#3
0
 def read_snli(paths='path/to/snli/data'):
     """Load and preprocess the SNLI corpus with ``SNLILoader.process``.

     Args:
         paths: location of the SNLI data, forwarded unchanged to
             ``SNLILoader().process``. The default is the original hard-coded
             placeholder, so existing zero-argument calls behave as before.

     Returns:
         The object returned by ``SNLILoader().process``.
     """
     data_info = SNLILoader().process(
         paths=paths,
         to_lower=True,
         seq_len_type=None,
         bert_tokenizer=None,
         get_index=True,
         concat=False,
         # Extra characters to split tokens on (assumed loader semantics --
         # confirm against the SNLILoader documentation).
         extra_split=['/', '%', '-'],
     )
     return data_info
示例#4
0
# Experiment configuration; the fields read below (seed, task, seq_len_type,
# bert_dir) come from this object.
arg = BERTConfig()

# Seed Python's `random`, NumPy's global RNG and PyTorch's global generator
# from the same configured value so runs are reproducible.
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

# Additionally seed all CUDA devices, but only when at least one is visible.
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)

# Load the dataset for the configured task. Every branch calls its task's
# loader with the same `process` options: to_lower=True, the BERT tokenizer
# from `arg.bert_dir`, cut_text=512 (presumably a token limit -- confirm),
# get_index=True and concat='bert'.
if arg.task == 'snli':
    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'rte':
    data_info = RTELoader().process(
        paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
# NOTE(review): this snippet is truncated here -- the body of the 'mnli'
# branch (and any later branches) is not visible.
elif arg.task == 'mnli':
示例#5
0
 def test_SNLILoader(self):
     """The bundled SNLI sample file loads into exactly three instances."""
     sample_path = 'test/data_for_tests/sample_snli.jsonl'
     dataset = SNLILoader().load(sample_path)
     assert len(dataset) == 3
示例#6
0
                                     # NOTE(review): fragment -- the start of this call and the
                                     # leading `if` branch (presumably 'qnli') are not visible.
                                     get_index=True,
                                     concat=False,
                                     auto_pad_length=arg.max_len)
# The remaining datasets use the same `process` options as the branch above:
# to_lower=True, no BERT tokenizer, get_index=True, concat=False and
# auto_pad_length=arg.max_len (presumably a fixed padding length -- confirm).
elif arg.dataset == 'rte':
    data_info = RTELoader().process(paths='path/to/rte/data',
                                    to_lower=True,
                                    seq_len_type=arg.seq_len_type,
                                    bert_tokenizer=None,
                                    get_index=True,
                                    concat=False,
                                    auto_pad_length=arg.max_len)
elif arg.dataset == 'snli':
    data_info = SNLILoader().process(paths='path/to/snli/data',
                                     to_lower=True,
                                     seq_len_type=arg.seq_len_type,
                                     bert_tokenizer=None,
                                     get_index=True,
                                     concat=False,
                                     auto_pad_length=arg.max_len)
elif arg.dataset == 'mnli':
    data_info = MNLILoader().process(paths='path/to/mnli/data',
                                     to_lower=True,
                                     seq_len_type=arg.seq_len_type,
                                     bert_tokenizer=None,
                                     get_index=True,
                                     concat=False,
                                     auto_pad_length=arg.max_len)
else:
    # Any other dataset name is rejected.
    # NOTE(review): the f-prefix below is unnecessary -- the string has no
    # placeholders.
    raise ValueError(
        f'now we only support [qnli,rte,snli,mnli] dataset for cntn model!')
示例#7
0
# Experiment configuration; the fields read below (seed, task, seq_len_type)
# come from this object.
arg = ESIMConfig()

# Seed Python's `random`, NumPy's global RNG and PyTorch's global generator
# from the same configured value so runs are reproducible.
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

# Additionally seed all CUDA devices, but only when at least one is visible.
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)

# Load the dataset for the configured task. Every branch calls its task's
# loader with the same `process` options: to_lower=False, get_index=True,
# concat=False and no bert_tokenizer argument.
if arg.task == 'snli':
    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
elif arg.task == 'rte':
    data_info = RTELoader().process(
        paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
elif arg.task == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
# NOTE(review): this snippet is truncated inside the 'mnli' call below; the
# closing parenthesis and any later branches are not visible.
elif arg.task == 'mnli':
    data_info = MNLILoader().process(
        paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,