def generate_counts(example):
    """Collect size statistics for one source/target example.

    Reads ``spacy_source_toks`` and ``spacy_target_toks`` from ``example``
    and returns a pair of dicts:

    * the first carries ``mrn`` and ``account`` plus token, sentence and
      document counts for both sides (``target_docs`` is always 1);
    * the second carries, for each side, the list of per-sentence lengths
      measured in single-space-delimited pieces.
    """
    source_html = example['spacy_source_toks']
    target_html = example['spacy_target_toks']

    src_sentences = sents_from_html(source_html)
    tgt_sentences = sents_from_html(target_html)
    src_tokens = sent_toks_from_html(source_html)
    tgt_tokens = sent_toks_from_html(target_html)

    src_lengths = [len(sentence.split(' ')) for sentence in src_sentences]
    tgt_lengths = [len(sentence.split(' ')) for sentence in tgt_sentences]

    # Counts occurrences of the literal text "d note_id" in the raw source;
    # presumably one per document tag -- TODO confirm against the markup.
    doc_matches = re.findall(r'd note_id', source_html)

    counts = {
        'mrn': example['mrn'],
        'account': example['account'],
        'source_toks': len(src_tokens),
        'target_toks': len(tgt_tokens),
        'source_sents': len(src_sentences),
        'target_sents': len(tgt_sentences),
        'source_docs': len(doc_matches),
        'target_docs': 1,
    }
    lengths = {
        'source_sent_lens': src_lengths,
        'target_sent_lens': tgt_lengths,
    }
    return counts, lengths
def build(sent_order, sents, record, target_tok_ct=TARGET_TOK_CT):
    """Greedily assemble an extractive summary under a token budget.

    Sentences are visited in the order given by ``sent_order`` (indices into
    ``sents``). A sentence whose text is already in the summary is skipped.
    Selection stops at the first sentence that would push the running length
    past ``target_tok_ct``; the very first sentence is always accepted, even
    if it alone exceeds the budget.

    :param sent_order: iterable of indices into ``sents``, best first.
    :param sents: non-empty sequence of sentence strings.
    :param record: mapping providing ``mrn``, ``account`` and
        ``spacy_target_toks``.
    :param target_tok_ct: maximum summary length, in single-space-delimited
        pieces.
    :return: ``(result, sent_lens)`` where ``result`` is a dict with the
        reference, the prediction, their lengths and the stringified indices
        of the sentences actually used, and ``sent_lens`` lists the length of
        each selected sentence.
    :raises AssertionError: if ``sents`` is empty.
    """
    assert len(sents) > 0

    summary_sents = []
    seen_sents = set()  # O(1) duplicate check instead of scanning the list
    used_idxs = []
    sent_lens = []
    sum_len = 0

    # build summaries by total length
    for sent_idx in sent_order:
        sent = sents[sent_idx]
        if sent in seen_sents:
            continue
        this_len = len(sent.split(' '))
        if summary_sents and sum_len + this_len > target_tok_ct:
            break
        sent_lens.append(this_len)
        summary_sents.append(sent)
        seen_sents.add(sent)
        # Record the index that was really used. Slicing ``sent_order`` by the
        # number of selected sentences is wrong as soon as a duplicate has
        # been skipped, because skipped indices would be reported instead.
        used_idxs.append(sent_idx)
        sum_len += this_len

    prediction = ' <s> '.join(summary_sents).strip()

    target_sents = sents_from_html(record['spacy_target_toks'], convert_lower=True)
    n = len(target_sents)
    reference = ' <s> '.join(target_sents).strip()
    # Joining n sentences inserts n - 1 pseudo sentence tokens, so only that
    # many are subtracted. An empty reference has length 0.
    ref_len = (len(reference.split(' ')) - (n - 1)) if n > 0 else 0

    return {
        'mrn': record['mrn'],
        'account': record['account'],
        'reference': reference,
        'prediction': prediction,
        'ref_len': ref_len,
        'sum_len': sum_len,
        'sent_order': stringify_list(used_idxs)
    }, sent_lens
def gen_summaries(record):
    """Produce a summary record for one example using ``summarizer``.

    The target text is passed through ``resolve_course`` before sentence
    extraction; both source and target sentences are extracted with
    ``convert_lower=True``. The fields returned by
    ``summarizer(source_sents, target_sents)`` are merged into the result and
    override any identically named key.

    :param record: mapping providing ``account``, ``mrn``,
        ``spacy_source_toks`` and ``spacy_target_toks``.
    :return: dict with ``account``, ``mrn``, ``reference``, ``ref_len`` plus
        whatever ``summarizer`` returned.
    """
    target_sents = sents_from_html(
        resolve_course(record['spacy_target_toks']), convert_lower=True)
    source_sents = sents_from_html(record['spacy_source_toks'], convert_lower=True)
    pred_obj = summarizer(source_sents, target_sents)

    n = len(target_sents)
    reference = ' <s> '.join(target_sents).strip()
    # Joining n sentences inserts n - 1 pseudo sentence tokens, so only that
    # many are subtracted. An empty reference has length 0.
    ref_len = (len(reference.split(' ')) - (n - 1)) if n > 0 else 0

    obj = {
        'account': record['account'],
        'mrn': record['mrn'],
        'reference': reference,
        'ref_len': ref_len,
    }
    obj.update(pred_obj)
    return obj
def compute_lr_stats(record):
    """Pair each unique source sentence's rank score with its ROUGE score.

    The source sentences (lower-cased, de-duplicated through a ``set``, so
    their order is arbitrary) are scored with ``lxr.rank_sentences`` and,
    independently, compared against the lower-cased target text using the
    mean of the ROUGE-1 and ROUGE-2 F-measures.

    Returns a list of ``(rank_score, rouge_score)`` tuples, one per unique
    sentence.
    """
    unique_sents = list(set(
        sents_from_html(record['spacy_source_toks'], convert_lower=True)))
    reference = prepare_str_for_rouge(record['spacy_target_toks'].lower())

    ranked = lxr.rank_sentences(
        unique_sents,
        threshold=0.1,
        fast_power_method=True,
    )
    lr_scores = np.array(list(ranked))

    num_sents = len(unique_sents)
    predictions = [prepare_str_for_rouge(sent) for sent in unique_sents]
    # Every sentence is scored against the same reference text.
    references = [reference] * num_sents

    rouge_types = ['rouge1', 'rouge2']
    outputs = compute(
        predictions, references, rouge_types=rouge_types, use_aggregator=False)

    num_types = float(len(rouge_types))
    mean_fmeasures = []
    for i in range(num_sents):
        total = sum([outputs[rouge_type][i].fmeasure for rouge_type in rouge_types])
        mean_fmeasures.append(total / num_types)
    r_scores = np.array(mean_fmeasures)

    return list(zip(lr_scores, r_scores))
def compute_lr(record):
    """Build rank-ordered summaries with and without sentence de-duplication.

    Source sentences are scored by ``lxr.rank_sentences`` twice: once over the
    full sentence list and once over the unique sentences (obtained through a
    ``set``, so their order is arbitrary). Each ranking, highest score first,
    is handed to ``build`` to create a summary.

    Returns ``(o1, o2, mean_sent_len, frac_uniq)``: the summary built from all
    sentences, the summary built from unique sentences, the mean length of
    the sentences selected for the first summary, and the fraction of source
    sentences that are unique.
    """
    all_sents = sents_from_html(record['spacy_source_toks'], convert_lower=True)
    deduped_sents = list(set(all_sents))
    rank_kwargs = {'threshold': 0.1, 'fast_power_method': True}

    scores_all = np.array(list(lxr.rank_sentences(all_sents, **rank_kwargs)))
    frac_uniq = len(deduped_sents) / float(len(all_sents))
    scores_deduped = np.array(
        list(lxr.rank_sentences(deduped_sents, **rank_kwargs)))

    # Negating the scores makes argsort yield indices from best to worst.
    order_all = np.argsort(-scores_all)
    order_deduped = np.argsort(-scores_deduped)

    summary_all, selected_lens = build(order_all, all_sents, record)
    summary_deduped, _ = build(order_deduped, deduped_sents, record)

    mean_sent_len = np.array(selected_lens).mean()
    return summary_all, summary_deduped, mean_sent_len, frac_uniq
def gen_summaries(record):
    """Build a retrieval-based summary: one retrieved sentence per target sentence.

    For every target sentence, the query produced by ``gen_query`` from its
    single-space-delimited pieces is sent to ``bm25.get_top_n`` over
    ``corpus`` and the top hit is kept. The hits are joined with ``' <s> '``
    to form the prediction; the target sentences are joined the same way to
    form the reference.

    :param record: mapping providing ``account``, ``mrn`` and
        ``spacy_target_toks``.
    :return: dict with ``account``, ``mrn``, ``prediction``, ``reference``,
        ``sum_len`` and ``ref_len``.
    """
    target_sents = sents_from_html(record['spacy_target_toks'], convert_lower=True)
    summary_sents = [
        bm25.get_top_n(gen_query(sent.split(' ')), corpus, n=1)[0]
        for sent in target_sents
    ]

    n = len(target_sents)
    reference = ' <s> '.join(target_sents).strip()
    summary = ' <s> '.join(summary_sents).strip()

    if n == 0:
        # Nothing was joined; an empty string must not count as one token.
        ref_len = sum_len = 0
    else:
        # Joining n sentences inserts n - 1 pseudo sentence tokens, so only
        # that many are subtracted. The summary has exactly one sentence per
        # target sentence, hence the same number of separators.
        num_separators = n - 1
        ref_len = len(reference.split(' ')) - num_separators
        sum_len = len(summary.split(' ')) - num_separators

    return {
        'account': record['account'],
        'mrn': record['mrn'],
        'prediction': summary,
        'reference': reference,
        'sum_len': sum_len,
        'ref_len': ref_len,
    }
def aggregate(source_str):
    """Flatten the sentences extracted from ``source_str`` into one string.

    Sentences come from ``sents_from_html`` called with ``convert_lower=True``
    and are joined with single spaces.
    """
    sentences = sents_from_html(source_str, convert_lower=True)
    return ' '.join(sentences)
def generate_samples(row):
    """Generate single-step extraction examples by greedy ROUGE selection.

    Starting from an empty summary, each step scores every candidate source
    sentence by the mean ROUGE-1/ROUGE-2 F-measure obtained when it is
    appended to the current summary, records one training example holding the
    candidates still available and their scores, and then adds the
    best-scoring sentence to the summary. Generation stops when the step
    limit is reached or when the best candidate's gain or differential falls
    below ``MIN_ROUGE_IMPROVEMENT`` / ``MIN_ROUGE_DIFFERENTIAL``.

    :param row: mapping providing ``mrn``, ``account``, ``spacy_source_toks``
        and ``spacy_target_toks``.
    :return: ``(examples, rouge_diffs, rouge_gains, rouge_fulls)``. The last
        three are ``defaultdict(float)`` keyed by step index. ``examples`` is
        empty when the row is filtered out by the size limits.
    """
    rouge_types = ['rouge1', 'rouge2']
    single_extraction_examples = []
    rouge_diffs = defaultdict(float)
    rouge_gains = defaultdict(float)
    rouge_fulls = defaultdict(float)

    source_sents = sents_from_html(row['spacy_source_toks'], convert_lower=True)
    target_sents = sents_from_html(row['spacy_target_toks'], convert_lower=True)
    target_n = len(target_sents)
    if not eval_mode and target_n > MAX_TARGET_SENTS:
        return [], rouge_diffs, rouge_gains, rouge_fulls

    target = ' '.join(target_sents)
    target_no_stop = prepare_str_for_rouge(target)
    target_toks = set(target_no_stop.split(' '))
    source_sents_no_stop = list(map(prepare_str_for_rouge, source_sents))

    # Mark every repeat of an already-seen (prepared) sentence; the first
    # occurrence is the one that stays eligible.
    dup_idxs = set()
    seen = set()
    for idx, source_sent in enumerate(source_sents_no_stop):
        if source_sent in seen:
            dup_idxs.add(idx)
        else:
            seen.add(source_sent)

    # remove duplicate sentences and 1-2 word sentences (too many of them) and most are not necessary for BHC
    # NOTE(review): `type` is not bound inside this function. Unless the
    # module defines a variable named `type`, this compares the builtin to a
    # string and is always False -- confirm.
    should_keep_all = eval_mode or type == 'test'
    keep_idxs = [
        idx for idx, s in enumerate(source_sents_no_stop)
        if extraction_is_keep(
            s, target_toks,
            no_match_keep_prob=compute_no_match_keep_prob(len(source_sents), should_keep_all)
        ) and idx not in dup_idxs
    ]
    source_sents_no_stop_filt = [source_sents_no_stop[idx] for idx in keep_idxs]
    source_sents_filt = [source_sents[idx] for idx in keep_idxs]
    source_n = len(keep_idxs)
    if not should_keep_all and (source_n < target_n or source_n > MAX_SOURCE_SENTS):
        return [], rouge_diffs, rouge_gains, rouge_fulls

    curr_sum_sents = []
    curr_rouge = 0.0
    included_sent_idxs = set()
    max_sum_n = min(source_n, len(target_sents), MAX_SUM_SENTS)
    if eval_mode:
        # Evaluation takes a single step, but only when there is at least one
        # candidate: argmax over an empty score array raises ValueError.
        max_sum_n = min(1, source_n)

    references = [target_no_stop for _ in range(source_n)]
    num_rouge_types = float(len(rouge_types))
    for gen_idx in range(max_sum_n):
        curr_sum = prepare_str_for_rouge(' '.join(curr_sum_sents).strip() + ' ')
        predictions = [(curr_sum + s).strip() for s in source_sents_no_stop_filt]
        outputs = compute(
            predictions=predictions, references=references, rouge_types=rouge_types,
            use_aggregator=False)
        scores = np.array([
            sum([outputs[t][i].fmeasure for t in rouge_types]) / num_rouge_types
            for i in range(source_n)
        ])

        # Already-selected sentences must win neither the max nor the min:
        # mask them with -inf in `scores` and +inf in the copy used for min.
        scores_pos_mask = scores.copy()
        if included_sent_idxs:
            included = list(included_sent_idxs)
            scores[included] = float('-inf')
            scores_pos_mask[included] = float('inf')

        max_idx = int(np.argmax(scores))
        max_score = scores[max_idx]
        assert max_idx not in included_sent_idxs
        min_score = scores_pos_mask.min()
        max_differential = max_score - min_score
        max_gain = max_score - curr_rouge

        # `score > -1` drops the -inf placeholders of selected sentences.
        valid_scores = [score for score in scores if score > -1]
        rouge_diffs[gen_idx] = max_differential
        rouge_gains[gen_idx] = max_score - np.mean(valid_scores)
        rouge_fulls[gen_idx] = max_score
        if max_gain < MIN_ROUGE_IMPROVEMENT or max_differential < MIN_ROUGE_DIFFERENTIAL:
            break

        eligible_idxs = [i for i in range(len(scores)) if i not in included_sent_idxs]
        eligible_scores = [scores[i] for i in eligible_idxs]
        eligible_source_sents = [source_sents_filt[i] for i in eligible_idxs]

        # Example
        example = {
            'mrn': row['mrn'],
            'account': row['account'],
            'curr_sum_sents': curr_sum_sents.copy(),
            'candidate_source_sents': eligible_source_sents,
            'curr_rouge': curr_rouge,
            'target_rouges': eligible_scores,
            'target_sents': target_sents,
        }
        single_extraction_examples.append(example)

        curr_rouge = max_score
        curr_sum_sents.append(source_sents_filt[max_idx])
        included_sent_idxs.add(max_idx)

    assert len(curr_sum_sents) == len(set(curr_sum_sents))
    return single_extraction_examples, rouge_diffs, rouge_gains, rouge_fulls
# Script body: load validation records, optionally subsample them, load the
# pretrained tokenizer and next-sentence-prediction model, run `process` over
# each record's target sentences, and merge the per-record outputs key by key.
parser = argparse.ArgumentParser('Script to generate NSP scores')
parser.add_argument(
    '--pretrained_model',
    default='emilyalsentzer/Bio_ClinicalBERT',
    choices=['emilyalsentzer/Bio_ClinicalBERT'])
parser.add_argument('--max_n', default=-1, type=int)

args = parser.parse_args()

# The default of -1 leaves `mini` False; values 0..100 request the mini split.
mini = 0 <= args.max_n <= 100
validation_df = get_records(split='validation', mini=mini)
records = validation_df.to_dict('records')
if args.max_n > 0:
    # Fixed seed so the same subsample is drawn on every run.
    np.random.seed(1992)
    # Clamp the sample size: sampling without replacement raises ValueError
    # when more items are requested than are available.
    records = np.random.choice(
        records, size=min(args.max_n, len(records)), replace=False)

target_sents = [
    sents_from_html(record['spacy_target_toks']) for record in records
]
n = len(records)

print('Loading tokenizer...')
tokenizer = BertTokenizer.from_pretrained(args.pretrained_model)
print('Loading model...')
model = BertForNextSentencePrediction.from_pretrained(
    args.pretrained_model, return_dict=True)

print('Generating NSP predictions for {} examples'.format(n))
outputs = list(tqdm(map(process, target_sents), total=n))

# Concatenate the per-record lists under each output key.
agg_output = defaultdict(list)
for output in outputs:
    for k, v in output.items():
        agg_output[k] += v