Example #1
def get_batch_trans(t_batch, is_test=True, n_ext=5):
    # t_batch: iterable of (source, tgt) pairs, each a list of sentence strings
    # print(t_batch)
    dataset = []
    for source, tgt in t_batch:
        source = [s.split(" ") for s in source]
        tgt = [s.split(" ") for s in tgt]
        # print(source, tgt)
        sent_labels = greedy_selection(source[:args_bert.max_src_nsents], tgt, n_ext)
        # print(sent_labels)
        if (args_bert.lower):
            source = [' '.join(s).lower().split() for s in source]
            tgt = [' '.join(s).lower().split() for s in tgt]
        b_data = bert.preprocess(source, tgt, sent_labels, use_bert_basic_tokenizer=args_bert.use_bert_basic_tokenizer,
                                is_test=is_test, order=True)
        src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt, ord_labels = b_data
        b_data_dict = {"src": src_subtoken_idxs, "tgt": tgt_subtoken_idxs,
                    "src_sent_labels": sent_labels, "segs": segments_ids, 'clss': cls_ids,
                    'src_txt': src_txt, "tgt_txt": tgt_txt, "ord_labels": ord_labels}
        dataset.append(b_data_dict)

    batch = []  
    for ex in dataset:
        ex = preprocess(args_train, ex, is_test, order=True)
        batch.append(ex)
    batch = models.data_loader.Batch(batch, device, is_test, order=True)
    return batch
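
A hedged usage sketch for get_batch_trans (the sentences are invented; the module-level objects the helper relies on, namely args_bert, args_train, bert, models, device, greedy_selection and preprocess, are assumed to be initialised elsewhere):

raw_srcs = [["the cat sat on the mat .", "it was a sunny day ."]]
raw_tgts = [["a cat sat on a mat ."]]
batch = get_batch_trans(list(zip(raw_srcs, raw_tgts)), is_test=True, n_ext=2)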
Example #2
def extractor_batch(raw_srcs, raw_tgts, is_test=True, n_ext=3):  # raw_srcs is a batch of articles (each a list of sentence strings)
    dataset = []
    for source, tgt in zip(raw_srcs, raw_tgts):
        source = [s.split(" ") for s in source]
        tgt = [s.split(" ") for s in tgt]
        sent_labels = greedy_selection(source[:args_bert.max_src_nsents], tgt, n_ext)
        # print(sent_labels)
        if (args_bert.lower):
            source = [' '.join(s).lower().split() for s in source]
            tgt = [' '.join(s).lower().split() for s in tgt]
        b_data = bert.preprocess(source, tgt, sent_labels, use_bert_basic_tokenizer=args_bert.use_bert_basic_tokenizer,
                                is_test=is_test)
        src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt = b_data
        b_data_dict = {"src": src_subtoken_idxs, "tgt": tgt_subtoken_idxs,
                    "src_sent_labels": sent_labels, "segs": segments_ids, 'clss': cls_ids,
                    'src_txt': src_txt, "tgt_txt": tgt_txt}
        dataset.append(b_data_dict)

    batch = []  
    for ex in dataset:
        ex = preprocess(args_train, ex, is_test)
        batch.append(ex)
    batch = models.data_loader.Batch(batch, device, is_test)
    step = -1  # so that no ROUGE calculation is done
    return trainer.test_extract([batch], step, n_ext=n_ext), batch
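
A similar hedged sketch for extractor_batch, which additionally needs the module-level trainer object; the returned pair is the extraction result from trainer.test_extract plus the constructed Batch (the sentences below are invented):

raw_srcs = [["moscow is the capital of russia .", "it is also its largest city ."]]
raw_tgts = [["moscow is russia 's capital and largest city ."]]
selection, batch = extractor_batch(raw_srcs, raw_tgts, is_test=True, n_ext=1)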
Example #3
def _format_to_nnsum(params, sent_count=5):
    f, args, input_dir, abstracts_dir, label_dir = params
    #print(f)
    doc_id = f.split('/')[-1].split('.')[0]  # e.g. 0000bf554ca24b0c72178403b54c0cca62d9faf8.story.json
    source, tgt = load_json(f, args.lower)
    if len(source) < 1 or len(tgt) < 1:
        return
    if (args.oracle_mode == 'greedy'):
        oracle_ids = greedy_selection(source, tgt, sent_count)
    elif (args.oracle_mode == 'combination'):
        oracle_ids = combination_selection(source, tgt, sent_count)
    # we should filter out empty files here
    labels = [1 if idx in oracle_ids else 0 for idx in range(len(source))]
    label_str = {"id": doc_id, "labels": labels}
    label_file = label_dir / "{}.json".format(doc_id)
    label_file.write_text(json.dumps(label_str))

    inputs = [{"tokens": sent, "text": " ".join(sent)} for sent in source]
    entry = {"id": doc_id, "inputs": inputs}
    input_file = input_dir / "{}.json".format(doc_id)
    input_file.write_text(json.dumps(entry))

    lines = [" ".join(sent) for sent in tgt]
    target_str = "\n".join(lines)
    abstract_file = abstracts_dir / "{}.spl".format(doc_id)
    abstract_file.write_text(target_str)
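
A minimal driver sketch for _format_to_nnsum (the directory layout and the story_files list are assumptions; the three output directories must be pathlib.Path objects because the helper uses the / operator and .write_text):

from pathlib import Path

input_dir = Path("nnsum/inputs")
abstracts_dir = Path("nnsum/abstracts")
label_dir = Path("nnsum/labels")
for d in (input_dir, abstracts_dir, label_dir):
    d.mkdir(parents=True, exist_ok=True)
for f in story_files:  # story_files: hypothetical list of *.story.json paths
    _format_to_nnsum((f, args, input_dir, abstracts_dir, label_dir), sent_count=5)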
Example #4
def _format_to_bert(params, sent_count=5):
    json_file, args, save_file = params
    if (os.path.exists(save_file)):
        logger.info('Ignore %s' % save_file)
        return

    bert = BertData(args)

    logger.info('Processing %s' % json_file)
    jobs = json.load(open(json_file))
    datasets = []
    for d in jobs:
        doc_id, source, tgt = d['docId'], d['src'], d['tgt']
        if (args.oracle_mode == 'greedy'):
            oracle_ids = greedy_selection(source, tgt, sent_count)
        elif (args.oracle_mode == 'combination'):
            oracle_ids = combination_selection(source, tgt, sent_count)
        #print(oracle_ids)
        b_data = bert.preprocess(source, tgt, oracle_ids)
        if (b_data is None):
            continue
        indexed_tokens, labels, segments_ids, cls_ids, src_txt, tgt_txt = b_data
        #print(labels)
        b_data_dict = {
            "doc_id": doc_id,
            "src": indexed_tokens,
            "labels": labels,
            "segs": segments_ids,
            'clss': cls_ids,
            'src_txt': src_txt,
            "tgt_txt": tgt_txt
        }
        datasets.append(b_data_dict)
    logger.info('Saving to %s' % save_file)
    torch.save(datasets, save_file)
    datasets = []
    gc.collect()
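
A hypothetical driver sketch: helpers shaped like _format_to_bert are usually mapped over (json_file, args, save_file) tuples with a process pool; the shard_pairs list and the n_cpus attribute below are assumptions, not part of the code above:

from multiprocessing import Pool

a_lst = [(json_f, args, pt_f) for json_f, pt_f in shard_pairs]  # shard_pairs is hypothetical
with Pool(args.n_cpus) as pool:  # n_cpus is an assumed argument
    for _ in pool.imap_unordered(_format_to_bert, a_lst):
        pass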
Example #5
def _format_to_jigsaw(params):
    # Compared with BertData's preprocess, this only adds utilities for loading and
    # saving the data, lowercasing, and skipping empty examples.
    corpus_type, json_file, args, save_file = params
    is_test = corpus_type == 'test'
    if (os.path.exists(save_file)):
        logger.info('Ignore %s' % save_file)
        return

    bert = JigsawData(args)

    logger.info('Processing %s' % json_file)
    jobs = json.load(open(json_file))  # one jobs list is one shard; it is now inflated times-fold
    if args.sample_near:
        datasets = []
        for d in jobs:
            source, tgt = d['src'], d['tgt']
            sent_labels = greedy_selection(source[:args.max_src_nsents], tgt, 3)
            if (args.lower):
                source = [' '.join(s).lower().split() for s in source]
                tgt = [' '.join(s).lower().split() for s in tgt]
            b_data = bert.preprocess_jigsaw(
                source,
                tgt,
                sent_labels,
                use_bert_basic_tokenizer=args.use_bert_basic_tokenizer,
                is_test=is_test,
                times=args.times,
                unchange_prob=args.unchange_prob)
            if (b_data is None):
                continue
            (src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids,
             cls_ids, src_txt, tgt_txt, poss, org_sent_labels) = b_data[0]
            for i in range(0, args.times):
                (src_subtoken_idxs_s, sent_labels_s, segments_ids_s, cls_ids_s,
                 src_txt_s, poss_s, org_sent_labels_s) = b_data[i + 1]
                if args.keep_orgdata:
                    b_data_dict = {
                        'src_s': src_subtoken_idxs_s,
                        "tgt": tgt_subtoken_idxs,
                        "segs_s": segments_ids_s,
                        'clss_s': cls_ids_s,
                        "src_sent_labels_s": sent_labels_s,
                        "org_sent_labels_s": org_sent_labels_s,
                        'poss_s': poss_s,
                        'src_txt_s': src_txt_s,
                        "tgt_txt": tgt_txt,
                        'src_txt': src_txt,
                        "src": src_subtoken_idxs,
                        "segs": segments_ids,
                        'clss': cls_ids,
                        "src_sent_labels": sent_labels,
                        "org_sent_labels": org_sent_labels,
                        'poss': poss,
                    }
                else:
                    b_data_dict = {
                        'src_s': src_subtoken_idxs_s,
                        "tgt": tgt_subtoken_idxs,
                        "segs_s": segments_ids_s,
                        'clss_s': cls_ids_s,
                        "src_sent_labels_s": sent_labels_s,
                        "org_sent_labels_s": org_sent_labels_s,
                        'poss_s': poss_s,
                        'src_txt_s': src_txt_s,
                        "tgt_txt": tgt_txt,
                        'src_txt': src_txt,
                    }
                datasets.append(b_data_dict)
        logger.info('Processed instances %d' % len(datasets))
        logger.info('Saving to %s' % save_file)
        torch.save(datasets, save_file)
        datasets = []
    else:
        # datasets = []
        dataset_list = [[] for i in range(args.times)]
        for d in jobs:
            source, tgt = d['src'], d['tgt']
            sent_labels = greedy_selection(source[:args.max_src_nsents], tgt, 3)

            if (args.lower):
                source = [' '.join(s).lower().split() for s in source]
                tgt = [' '.join(s).lower().split() for s in tgt]
            b_data = bert.preprocess_jigsaw(
                source,
                tgt,
                sent_labels,
                use_bert_basic_tokenizer=args.use_bert_basic_tokenizer,
                is_test=is_test,
                times=args.times,
                unchange_prob=args.unchange_prob)
            if (b_data is None):
                continue
            (src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids,
             cls_ids, src_txt, tgt_txt, poss, org_sent_labels) = b_data[0]
            for i in range(args.times):
                (src_subtoken_idxs_s, sent_labels_s, segments_ids_s, cls_ids_s,
                 src_txt_s, poss_s, org_sent_labels_s) = b_data[i + 1]
                if args.keep_orgdata:
                    b_data_dict = {
                        'src_s': src_subtoken_idxs_s,
                        "tgt": tgt_subtoken_idxs,
                        "segs_s": segments_ids_s,
                        'clss_s': cls_ids_s,
                        "src_sent_labels_s": sent_labels_s,
                        "org_sent_labels_s": org_sent_labels_s,
                        'poss_s': poss_s,
                        'src_txt_s': src_txt_s,
                        "tgt_txt": tgt_txt,
                        'src_txt': src_txt,
                        "src": src_subtoken_idxs,
                        "segs": segments_ids,
                        'clss': cls_ids,
                        "src_sent_labels": sent_labels,
                        "org_sent_labels": org_sent_labels,
                        'poss': poss,
                    }
                else:
                    b_data_dict = {
                        'src_s': src_subtoken_idxs_s,
                        "tgt": tgt_subtoken_idxs,
                        "segs_s": segments_ids_s,
                        'clss_s': cls_ids_s,
                        "src_sent_labels_s": sent_labels_s,
                        "org_sent_labels_s": org_sent_labels_s,
                        'poss_s': poss_s,
                        'src_txt_s': src_txt_s,
                        "tgt_txt": tgt_txt,
                        'src_txt': src_txt,
                    }
                dataset_list[i].append(b_data_dict)
        for j in range(args.times):
            logger.info('Processed instances %d' % len(dataset_list[j]))
            # train shards run 0..143, valid 0..6, test 0..5
            # e.g. cnndm_sample.train.0.bert.pt, cnndm.valid.6.bert.pt
            filename = save_file.split('/')[-1]
            save_file_list = []
            save_file_split = filename.split('.')
            save_file_list.extend(save_file_split[:2])
            # e.g. ../jigsaw_data/cnndm.train.93.bert.pt
            if save_file_split[0] == 'cnndm_sample':
                total = 1
            else:
                if save_file_split[1] == 'train':
                    total = 144
                elif save_file_split[1] == 'valid':
                    total = 7
                else:
                    total = 6
            save_file_list.append(str(total * j + int(save_file_split[2])))
            save_file_list.extend(save_file_split[3:])
            final_save_file = args.save_path + os.sep + '.'.join(save_file_list)
            logger.info('Saving to %s' % final_save_file)
            torch.save(dataset_list[j], final_save_file)
        dataset_list = []
    gc.collect()
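
A worked check of the shard-renaming arithmetic above, assuming the shard index is the third dot-separated field of the file name:

save_file_split = "cnndm.train.93.bert.pt".split('.')
total = 144  # train shards run 0..143
for j in range(3):  # e.g. args.times == 3
    parts = save_file_split[:2] + [str(total * j + int(save_file_split[2]))] + save_file_split[3:]
    print('.'.join(parts))
# -> cnndm.train.93.bert.pt, cnndm.train.237.bert.pt, cnndm.train.381.bert.pt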
Example #6
 line_src = src.readline()
 line_tgt = tgt.readline()
 if not line_src and not line_tgt:
     break
 if len(hyp_src_lines) % 10 == 0:  # show progress every 10 processed lines
     clear()
     print(len(hyp_src_lines))
 hyp = [
     word_tokenize(s, language="russian")
     for s in sent_tokenize(line_src, language="russian")
 ]
 ref = [
     word_tokenize(s, language="russian")
     for s in sent_tokenize(line_tgt, language="russian")
 ]
 result_idx = greedy_selection(hyp, ref, args.count)
 if len(result_idx) == 0:
     print("Skip!")
     continue
 result_idx.sort()
 hyp_src_line = " ".join(
     [" ".join(hyp[idx]) for idx in result_idx])
 try:
     scores = rouge.get_scores(hyp_src_line, line_tgt)
 except Exception:  # rouge can raise on empty or degenerate hypotheses
     hyp_src_line = "rouge_error"
 hyp_src_lines.append(hyp_src_line)
 ref_src_line = " ".join([" ".join(r) for r in ref])
 ref_src_line = (ref_src_line.replace(" .", ".").replace(" ,", ",")
                 .replace(" ?", "?").replace(" !", "!")
                 .replace(" »", "»").replace(" «", "«"))