Code example #1
import argparse
import json
import os

import torch
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm

# BertAbsSum, LCSTSProcessor, convert_examples_to_features, create_dataset,
# BATCH_SIZE and logger come from the surrounding project; the BertTokenizer
# import from its BERT library is likewise omitted in this excerpt.
parser = argparse.ArgumentParser()
# ... earlier argument definitions omitted ...
parser.add_argument("--max_tgt_len", type=int,
                    help="Max sequence length for target text. Sequences will be truncated or padded to this length")

if __name__ == "__main__":
    args = parser.parse_args()

    # Pick the GPU when available, load the decoder config and restore the
    # trained BertAbsSum weights.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open(args.config_path, 'r') as f:
        config = json.load(f)
    model = BertAbsSum(args.bert_model, config['decoder_config'], device)
    # map_location lets a GPU-saved checkpoint be restored on a CPU-only machine.
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.to(device)

    processor = LCSTSProcessor()
    tokenizer = BertTokenizer.from_pretrained(os.path.join(args.bert_model, 'vocab.txt'))
    test_examples = processor.get_examples(args.eval_path)
    test_features = convert_examples_to_features(test_examples, args.max_src_len, args.max_tgt_len, tokenizer)
    test_data = create_dataset(test_features)
    test_sampler = RandomSampler(test_data)
    # drop_last=True discards the final partial batch, so a few test examples
    # may be skipped during scoring.
    test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=BATCH_SIZE, drop_last=True)
    logger.info('Loading complete. Writing results to %s' % (args.result_path))

    model.eval()
    f_log = open(os.path.join(args.result_path, 'log.txt'), 'w', encoding='utf-8')
    # f_hyp = open(os.path.join(args.result_path, 'hyp.txt'), 'w', encoding='utf-8')
    # f_ref = open(os.path.join(args.result_path, 'ref.txt'), 'w', encoding='utf-8')
    hyp_list = []
    ref_list = []
    for batch in tqdm(test_dataloader, desc="Iteration"):
        batch = tuple(t.to(device) for t in batch)
        pred, _ = model.beam_decode(batch[0], batch[1], 3, 3)
        src, tgt = batch[0], batch[2]
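The excerpt ends right after beam decoding. Below is a minimal sketch of how the decoded id tensors could be turned back into text for hyp_list / ref_list; the structure of pred returned by BertAbsSum.beam_decode is project-specific, so treating it as one id sequence per source example, and the ids_to_text helper itself, are assumptions.

def ids_to_text(ids, tokenizer):
    # Hypothetical helper: map an id sequence back to a space-joined token
    # string, dropping BERT special tokens.
    tokens = tokenizer.convert_ids_to_tokens([int(t) for t in ids])
    return ' '.join(t for t in tokens if t not in ('[CLS]', '[SEP]', '[PAD]'))

# Inside the batch loop, hypotheses and references could then be collected as:
#     for i in range(src.size(0)):
#         hyp_list.append(ids_to_text(pred[i], tokenizer))
#         ref_list.append(ids_to_text(tgt[i], tokenizer))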
Code example #2
File: dataset.py  Project: cscyuge/seq2seq
def build_dataset_eval(config):
    """Build evaluation data for the test split and return (seq_contents,
    dcmn_contents): seq2seq inputs plus DCMN-style multiple-choice features."""
    abbrs_path = './data/abbrs-all-uncased.pkl'
    txt_path = './data/test(2030).txt'
    with open(abbrs_path, 'rb') as f:
        abbrs = pickle.load(f)
    src_txt, tar_1_txt, tar_2_txt = get_test_src_tar_txt(
        txt_path, config.tokenizer)
    seq_srcs = []
    dcmn_srcs = []
    dcmn_labels = []
    key_choices = []

    with open('./data/test_mask_step2_2030.pkl', 'rb') as f:
        mask_step1 = pickle.load(f)

    k_as = []
    # First pass: align each source with its ground-truth target and keep only
    # the key answers; the DCMN inputs themselves are rebuilt from the step-1
    # masks in the second pass below.
    for i, (src, tar) in enumerate(zip(src_txt, tar_1_txt)):
        src = word_tokenize(src, config.tokenizer)
        tar = word_tokenize(tar, config.tokenizer)
        sentences, labels, _src, key_ans, _, _tar = get_dcmn_data_from_gt(
            src,
            tar,
            abbrs,
            max_pad_length=config.num_choices + 2,
            max_dcmn_seq_length=config.max_seq_length,
            tokenizer=config.tokenizer)
        k_as.append(key_ans)

    for i, (sts, masks, k_a) in enumerate(zip(src_txt, mask_step1, k_as)):
        sts = word_tokenize(sts, config.tokenizer)
        assert len(sts) == len(masks)
        sentences, labels, _src, k_cs = get_dcmn_data_from_step1(
            sts,
            masks,
            k_a,
            abbrs,
            max_pad_length=config.num_choices + 2,
            max_dcmn_seq_length=config.max_seq_length,
            tokenizer=config.tokenizer)
        dcmn_srcs.extend(sentences)
        dcmn_labels.extend(labels)
        # Sanity check: one DCMN question is expected for every '[UNK]'
        # placeholder left in the rewritten source.
        if len(sentences) != _src.count('[UNK]'):
            print(i, sts)
        seq_srcs.append(_src)
        key_choices.append(k_cs)

    for i in range(len(seq_srcs)):
        seq_srcs[i] = '[CLS] ' + seq_srcs[i] + ' [SEP]'

    with open('./data/test_cudics.pkl', 'rb') as f:
        cudics = pickle.load(f)
    with open('./data/test_tars.pkl', 'rb') as f:
        seq_tars = pickle.load(f)

    q_id = [i + 1 for i in range(len(dcmn_labels))]
    article = [u[0] for u in dcmn_srcs]
    question = [u[1] for u in dcmn_srcs]
    cts = []
    for i in range(config.num_choices):
        cts.append([u[i + 2] for u in dcmn_srcs])

    examples = [
        SwagExample(
            swag_id=s5,
            context_sentence=s1,
            start_ending=s2,
            endings=s3,
            label=s4,
        ) for i, (
            s1, s2, *s3, s4,
            s5) in enumerate(zip(article, question, *cts, dcmn_labels, q_id))
    ]

    features = convert_examples_to_features(examples, config.tokenizer,
                                            config.max_seq_length)
    input_ids = select_field(features, 'input_ids')
    input_mask = select_field(features, 'input_mask')
    segment_ids = select_field(features, 'segment_ids')
    doc_len = select_field(features, 'doc_len')
    ques_len = select_field(features, 'ques_len')
    option_len = select_field(features, 'option_len')
    labels = [f.label for f in features]

    dcmn_contents = []
    for i in range(len(input_ids)):
        dcmn_contents.append(
            (input_ids[i], input_mask[i], segment_ids[i], doc_len[i],
             ques_len[i], option_len[i], labels[i]))

    seq_contents = []
    for i in range(len(seq_srcs)):
        seq_contents.append(
            (seq_srcs[i], seq_tars[i], cudics[i], key_choices[i]))

    return seq_contents, dcmn_contents
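For reference, a minimal sketch of how the returned dcmn_contents tuples might be batched for evaluation, assuming every field has already been padded to a fixed length by convert_examples_to_features; the loader settings below are illustrative, not taken from the project.

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

seq_contents, dcmn_contents = build_dataset_eval(config)
# Transpose the list of 7-tuples into one sequence per field, then tensorize.
fields = list(zip(*dcmn_contents))
eval_data = TensorDataset(*(torch.tensor(f, dtype=torch.long) for f in fields))
eval_loader = DataLoader(eval_data,
                         sampler=SequentialSampler(eval_data),
                         batch_size=8)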
Code example #3
    # train data preprocess
    processor = LCSTSProcessor()
    tokenizer = BertTokenizer.from_pretrained(
        os.path.join(args.bert_model, 'vocab.txt'))
    logger.info('Loading train examples...')
    if not os.path.exists(os.path.join(args.data_dir, TRAIN_FILE)):
        raise ValueError(f'{os.path.join(args.data_dir, TRAIN_FILE)} does not exist.')
    train_examples = processor.get_examples(
        os.path.join(args.data_dir, TRAIN_FILE))
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    logger.info('Converting train examples to features...')
    train_features = convert_examples_to_features(train_examples,
                                                  args.max_src_len,
                                                  args.max_tgt_len, tokenizer)
    example = train_examples[0]
    example_feature = train_features[0]
    logger.info("*** Example ***")
    logger.info("guid: %s" % (example.guid))
    logger.info("src text: %s" % example.src)
    logger.info("src_ids: %s" %
                " ".join([str(x) for x in example_feature.src_ids]))
    logger.info("src_mask: %s" %
                " ".join([str(x) for x in example_feature.src_mask]))
    logger.info("tgt text: %s" % example.tgt)
    logger.info("tgt_ids: %s" %
                " ".join([str(x) for x in example_feature.tgt_ids]))
    logger.info("tgt_mask: %s" %
                " ".join([str(x) for x in example_feature.tgt_mask]))
Code example #4
File: dataset.py  Project: cscyuge/medlane-code
def build_dataset(config):
    """Build training data and return (seq_contents, dcmn_contents): seq2seq
    source/target pairs plus DCMN-style multiple-choice features."""
    abbrs_path = './data/abbrs-all-cased.pkl'
    # txt_path = './data/train(12809).txt'
    txt_path = os.path.join(config.data_dir, config.train_file)
    with open(abbrs_path, 'rb') as f:
        abbrs = pickle.load(f)
    src_txt, tar_1_txt, tar_2_txt = get_train_src_tar_txt(txt_path)
    # src_txt = src_txt[:100]
    # tar_1_txt = tar_1_txt[:100]
    # tar_2_txt = tar_2_txt[:100]

    seq_srcs = []
    seq_tars = []
    dcmn_srcs = []
    dcmn_labels = []
    key_choices = []

    for i, (src, tar) in enumerate(zip(src_txt, tar_1_txt)):
        src = word_tokenize(src, config.tokenizer)
        tar = word_tokenize(tar, config.tokenizer)
        sentences, labels, _src, key_ans, k_c, _tar = get_dcmn_data_from_gt(
            src,
            tar,
            abbrs,
            max_pad_length=config.num_choices + 2,
            max_dcmn_seq_length=config.max_seq_length,
            tokenizer=config.tokenizer)
        if len(sentences) != _src.count('[UNK]'):
            print(i, src, len(sentences))
        dcmn_srcs.extend(sentences)
        dcmn_labels.extend(labels)
        seq_srcs.append(_src)
        seq_tars.append(_tar)
        key_choices.append(k_c)

    for i in range(len(seq_srcs)):
        seq_srcs[i] = '[CLS] ' + seq_srcs[i] + ' [SEP]'

    q_id = [i + 1 for i in range(len(dcmn_labels))]
    article = [u[0] for u in dcmn_srcs]
    question = [u[1] for u in dcmn_srcs]
    cts = []
    for i in range(config.num_choices):
        cts.append([u[i + 2] for u in dcmn_srcs])

    examples = [
        SwagExample(
            swag_id=s5,
            context_sentence=s1,
            start_ending=s2,
            endings=s3,
            label=s4,
        ) for i, (
            s1, s2, *s3, s4,
            s5) in enumerate(zip(article, question, *cts, dcmn_labels, q_id))
    ]

    features = convert_examples_to_features(examples, config.tokenizer,
                                            config.max_seq_length)
    input_ids = select_field(features, 'input_ids')
    input_mask = select_field(features, 'input_mask')
    segment_ids = select_field(features, 'segment_ids')
    doc_len = select_field(features, 'doc_len')
    ques_len = select_field(features, 'ques_len')
    option_len = select_field(features, 'option_len')
    labels = [f.label for f in features]

    dcmn_contents = []
    for i in range(len(input_ids)):
        dcmn_contents.append(
            (input_ids[i], input_mask[i], segment_ids[i], doc_len[i],
             ques_len[i], option_len[i], labels[i]))

    seq_contents = []
    for i in range(len(seq_srcs)):
        seq_contents.append((seq_srcs[i], seq_tars[i], key_choices[i]))

    return seq_contents, dcmn_contents
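The starred unpacking in the examples comprehension (s1, s2, *s3, s4, s5) is easy to misread, so here is a toy illustration with a hypothetical num_choices of 3.

article = ['ctx']
question = ['q']
cts = [['opt_a'], ['opt_b'], ['opt_c']]   # num_choices == 3 for the demo
dcmn_labels = [1]
q_id = [1]
for s1, s2, *s3, s4, s5 in zip(article, question, *cts, dcmn_labels, q_id):
    # s3 collects all the middle elements, i.e. the per-choice endings.
    print(s1, s2, s3, s4, s5)   # -> ctx q ['opt_a', 'opt_b', 'opt_c'] 1 1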