예제 #1
0
def get_dataloader(data_dir, data_file, num_choices, tokenizer, max_seq_length, batch_size, dg):
    """Read a SWAG-style file and return a sequential DataLoader over its features.

    Args:
        data_dir: Directory containing ``data_file``.
        data_file: File name joined onto ``data_dir`` and passed to
            ``read_swag_examples``.
        num_choices: Number of answer choices; ``read_swag_examples`` is
            called with ``max_pad_length=num_choices + 2``.
        tokenizer: Tokenizer handed to ``convert_examples_to_features``.
        max_seq_length: Maximum sequence length for feature conversion.
        batch_size: Batch size of the returned DataLoader.
        dg: Forwarded unchanged to ``read_swag_examples`` (its meaning is not
            visible here).

    Returns:
        ``(dataloader, num_examples)``. Each batch contains, in this order:
        input_ids, input_mask, segment_ids, label, doc_len, ques_len,
        option_len. A SequentialSampler is used, so there is no shuffling.
    """
    swag_path = os.path.join(data_dir, data_file)
    examples = read_swag_examples(swag_path, max_pad_length=num_choices + 2, dg=dg)
    features = convert_examples_to_features(examples, tokenizer, max_seq_length)

    field_names = ('input_ids', 'input_mask', 'segment_ids', 'doc_len', 'ques_len', 'option_len')
    columns = {name: torch.LongTensor(select_field(features, name)) for name in field_names}
    labels = torch.LongTensor([feature.label for feature in features])

    # Consumers unpack batches positionally: the label tensor sits between
    # segment_ids and the three length tensors.
    dataset = TensorDataset(
        columns['input_ids'],
        columns['input_mask'],
        columns['segment_ids'],
        labels,
        columns['doc_len'],
        columns['ques_len'],
        columns['option_len'],
    )
    loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
    return loader, len(examples)
예제 #2
0
파일: dataset.py 프로젝트: cscyuge/seq2seq
def build_dataset_eval(config):
    """Build the evaluation datasets for the seq2seq model and the DCMN model.

    All inputs are read from fixed paths under ``./data``:
    - the abbreviation table;
    - the test source/target text;
    - the mask pickle ``test_mask_step2_2030.pkl``;
    - the pickled target sequences and ``cudics``.

    Args:
        config: Configuration object. This function reads ``tokenizer``,
            ``num_choices`` and ``max_seq_length`` from it.

    Returns:
        A ``(seq_contents, dcmn_contents)`` pair:

        * ``seq_contents``: one ``(src, tar, cudic, key_choices)`` tuple per
          test sentence, with ``src`` wrapped as ``'[CLS] ... [SEP]'``.
        * ``dcmn_contents``: one ``(input_ids, input_mask, segment_ids,
          doc_len, ques_len, option_len, label)`` tuple per DCMN example.

    Raises:
        ValueError: if a tokenized sentence and its mask differ in length, or
            if the pickled targets/cudics have fewer entries than there are
            test sentences.
    """
    abbrs_path = './data/abbrs-all-uncased.pkl'
    txt_path = './data/test(2030).txt'
    with open(abbrs_path, 'rb') as f:
        abbrs = pickle.load(f)
    # The second target column is not used for evaluation.
    src_txt, tar_1_txt, _ = get_test_src_tar_txt(txt_path, config.tokenizer)

    with open('./data/test_mask_step2_2030.pkl', 'rb') as f:
        mask_step1 = pickle.load(f)

    # Pass 1: only the key answers from the ground-truth alignment are needed.
    k_as = []
    for src, tar in zip(src_txt, tar_1_txt):
        src = word_tokenize(src, config.tokenizer)
        tar = word_tokenize(tar, config.tokenizer)
        _, _, _, key_ans, _, _ = get_dcmn_data_from_gt(
            src,
            tar,
            abbrs,
            max_pad_length=config.num_choices + 2,
            max_dcmn_seq_length=config.max_seq_length,
            tokenizer=config.tokenizer)
        k_as.append(key_ans)

    # Pass 2: build DCMN inputs from the loaded masks and the key answers.
    seq_srcs = []
    dcmn_srcs = []
    dcmn_labels = []
    key_choices = []
    for i, (sts, masks, k_a) in enumerate(zip(src_txt, mask_step1, k_as)):
        sts = word_tokenize(sts, config.tokenizer)
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(sts) != len(masks):
            raise ValueError(
                'sentence {}: {} tokens but {} mask entries'.format(
                    i, len(sts), len(masks)))
        sentences, labels, _src, k_cs = get_dcmn_data_from_step1(
            sts,
            masks,
            k_a,
            abbrs,
            max_pad_length=config.num_choices + 2,
            max_dcmn_seq_length=config.max_seq_length,
            tokenizer=config.tokenizer)
        dcmn_srcs.extend(sentences)
        dcmn_labels.extend(labels)
        # Diagnostic: report sentences whose DCMN example count differs from
        # the number of '[UNK]' tokens in the returned source.
        if len(sentences) != _src.count('[UNK]'):
            print(i, sts)
        seq_srcs.append('[CLS] ' + _src + ' [SEP]')
        key_choices.append(k_cs)

    # Use context managers so the pickle files are closed deterministically.
    with open('./data/test_cudics.pkl', 'rb') as f:
        cudics = pickle.load(f)
    with open('./data/test_tars.pkl', 'rb') as f:
        seq_tars = pickle.load(f)
    # Fail fast with a clear message instead of an IndexError after the
    # (expensive) feature conversion below.
    if len(seq_tars) < len(seq_srcs) or len(cudics) < len(seq_srcs):
        raise ValueError(
            'expected at least {} targets/cudics, got {} targets and {} cudics'.format(
                len(seq_srcs), len(seq_tars), len(cudics)))

    # Each DCMN source row is [article, question, choice_0, ..., choice_{n-1}].
    article = [u[0] for u in dcmn_srcs]
    question = [u[1] for u in dcmn_srcs]
    cts = [[u[k + 2] for u in dcmn_srcs] for k in range(config.num_choices)]

    # Example ids are 1-based positions.
    examples = [
        SwagExample(
            swag_id=q_id,
            context_sentence=s1,
            start_ending=s2,
            endings=s3,
            label=s4,
        ) for q_id, (s1, s2, *s3, s4) in enumerate(
            zip(article, question, *cts, dcmn_labels), start=1)
    ]

    features = convert_examples_to_features(examples, config.tokenizer,
                                            config.max_seq_length)
    input_ids = select_field(features, 'input_ids')
    input_mask = select_field(features, 'input_mask')
    segment_ids = select_field(features, 'segment_ids')
    doc_len = select_field(features, 'doc_len')
    ques_len = select_field(features, 'ques_len')
    option_len = select_field(features, 'option_len')
    feature_labels = [f.label for f in features]

    dcmn_contents = list(zip(input_ids, input_mask, segment_ids, doc_len,
                             ques_len, option_len, feature_labels))
    seq_contents = list(zip(seq_srcs, seq_tars, cudics, key_choices))

    return seq_contents, dcmn_contents
예제 #3
0
def build_dataset(config):
    """Build the training datasets for the seq2seq model and the DCMN model.

    Args:
        config: Configuration object. Reads ``data_dir``, ``train_file``,
            ``tokenizer``, ``num_choices`` and ``max_seq_length``.

    Returns:
        ``(seq_contents, dcmn_contents)``:

        * ``seq_contents``: one ``(src, tar, key_choices)`` tuple per training
          sentence, with ``src`` wrapped as ``'[CLS] ... [SEP]'``.
        * ``dcmn_contents``: one ``(input_ids, input_mask, segment_ids,
          doc_len, ques_len, option_len, label)`` tuple per DCMN example.
    """
    train_path = os.path.join(config.data_dir, config.train_file)
    with open('./data/abbrs-all-cased.pkl', 'rb') as abbr_file:
        abbrs = pickle.load(abbr_file)
    # The second target column is not used here.
    src_lines, tar_lines, _ = get_train_src_tar_txt(train_path)

    seq_srcs, seq_tars, key_choices = [], [], []
    dcmn_srcs, dcmn_labels = [], []
    for idx, (raw_src, raw_tar) in enumerate(zip(src_lines, tar_lines)):
        src_tokens = word_tokenize(raw_src, config.tokenizer)
        tar_tokens = word_tokenize(raw_tar, config.tokenizer)
        sentences, labels, seq_src, _, choices, seq_tar = get_dcmn_data_from_gt(
            src_tokens,
            tar_tokens,
            abbrs,
            max_pad_length=config.num_choices + 2,
            max_dcmn_seq_length=config.max_seq_length,
            tokenizer=config.tokenizer)
        # Diagnostic: report sentences whose DCMN example count differs from
        # the number of '[UNK]' tokens in the returned source.
        if len(sentences) != seq_src.count('[UNK]'):
            print(idx, src_tokens, len(sentences))
        dcmn_srcs.extend(sentences)
        dcmn_labels.extend(labels)
        seq_srcs.append(seq_src)
        seq_tars.append(seq_tar)
        key_choices.append(choices)

    seq_srcs = ['[CLS] ' + text + ' [SEP]' for text in seq_srcs]

    # Each DCMN source row is [article, question, choice_0, ..., choice_{n-1}].
    articles = [item[0] for item in dcmn_srcs]
    questions = [item[1] for item in dcmn_srcs]
    choice_columns = [[item[k + 2] for item in dcmn_srcs]
                      for k in range(config.num_choices)]

    # Example ids are 1-based positions.
    examples = [
        SwagExample(
            swag_id=swag_id,
            context_sentence=article,
            start_ending=question,
            endings=endings,
            label=label,
        )
        for swag_id, (article, question, *endings, label) in enumerate(
            zip(articles, questions, *choice_columns, dcmn_labels), start=1)
    ]

    features = convert_examples_to_features(examples, config.tokenizer,
                                            config.max_seq_length)
    field_values = [
        select_field(features, name)
        for name in ('input_ids', 'input_mask', 'segment_ids',
                     'doc_len', 'ques_len', 'option_len')
    ]
    feature_labels = [feature.label for feature in features]

    dcmn_contents = list(zip(*field_values, feature_labels))
    seq_contents = list(zip(seq_srcs, seq_tars, key_choices))

    return seq_contents, dcmn_contents