Code example #1
    def __call__(self):

        # 1. Fit the vocabulary (read all the text and decide which tokens to keep in the vocabulary)
        start_time = time()
        self.sle.fit(self.files_to_process)
        self.vocab_size = self.sle.features_n
        logging.info(f" - labels fit done {time() - start_time:.2f} seconds")

        # 2. Parse all the files and generate co-occurrence matrices for each file in a sparse format
        start_time = time()
        cooc_iterator = p_uimap(self.get_cooc_mat, self.files_to_process)

        cooc = coo_matrix((self.sle.features_n, self.sle.features_n))
        for partial_cooc in cooc_iterator:
            # 3. Add all the co-occurrence matrices together
            cooc += partial_cooc
        if self.temp_dir is not None:
            shutil.rmtree(self.temp_dir)
            logging.info(f"temp folder {self.temp_dir} deleted")
        logging.info(
            f" - cooc dict computed {time() - start_time:.2f} seconds")

        # 4. Format the sparse matrix to get indices and values separated
        start_time = time()
        cooc_rows, cooc_cols, cooc_data = self.glove_formatter(cooc)
        logging.info(f" - output formatted {time() - start_time:.2f} seconds")
        return cooc_rows, cooc_cols, cooc_data, cooc
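
The pattern above (map a worker over the files, then fold the sparse results into one matrix as they arrive) recurs throughout this page. Below is a minimal, self-contained sketch of it, assuming only that p_tqdm, numpy, and scipy are installed; the worker and its random data are made up for illustration.

import numpy as np
from p_tqdm import p_uimap
from scipy.sparse import coo_matrix

N = 100  # hypothetical vocabulary size

def partial_cooc(seed):
    # Hypothetical worker: one sparse co-occurrence matrix per input.
    rng = np.random.default_rng(seed)
    rows = rng.integers(0, N, size=50)
    cols = rng.integers(0, N, size=50)
    return coo_matrix((np.ones(50), (rows, cols)), shape=(N, N))

if __name__ == '__main__':
    total = coo_matrix((N, N))
    # p_uimap yields results in completion order; addition is commutative,
    # so the unordered iterator is safe for this reduction.
    for mat in p_uimap(partial_cooc, range(8)):
        total += mat
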
Code example #2
    def fit(self, X):
        """
        Creates the vocabulary.

        :param X: a single file path, a list of file paths, or an iterable of tokens
        :return: None
        """

        # 1. Create counter
        # 1.a Create counter from file
        if isinstance(X, str):
            self.counter = self.count_in_file(X)

        # 1.b Create counter from multiple files in parallel
        elif isinstance(X, list) and isinstance(X[0], str):
            self.counter = Counter()
            counter_iterator = p_uimap(self.count_in_file, X)

            for temp_counter in counter_iterator:
                self.counter += temp_counter

        # 1.c Build the counter directly from an iterable of tokens
        else:
            self.counter = Counter(X)

        # 2. Remove UNK token if present in the vocabulary (it will be added at the beginning of the vocabulary)
        if self.unk_token in self.counter:
            self.counter.pop(self.unk_token)

        # 3. Keep at most max_features - 1 tokens that appeared at least min_occurrence times
        self.counter = Counter(
            {k: c for k, c in self.counter.most_common(self.max_features - 1)
             if c >= self.min_occurrence}
        )

        # 4. Instantiate dictionaries val2ix, ix2val and list of classes
        self.val2ix[self.unk_token] = 0
        self.ix2val[0] = self.unk_token
        self.classes_ = [self.unk_token]

        # 5. Fill those dictionaries with the vocabulary and indexes
        for ix, (val, _) in enumerate(self.counter.most_common()):

            self.val2ix[val] = ix + 1
            self.ix2val[ix + 1] = val
            self.classes_.append(val)

        self.features_n = len(self.classes_)
Code example #3
    def populate_cache(self):
        print("Populating image cache for {}.".format(self.split))
        for image_name, image, mask in p_uimap(self.load_image, self.images):
            self.cache[image_name] = (image, mask)
Code example #4
File: ML_plot.py Project: ESMartiny/NetworkSIR
            # Get the network hashes
            network_hashes = set([
                utils.cfg_to_hash(cfg.network, exclude_ID=False)
                for cfg in cfgs
            ])

            # Get list of unique cfgs
            cfgs_network = []
            for cfg in cfgs:
                network_hash = utils.cfg_to_hash(cfg.network, exclude_ID=False)

                if network_hash in network_hashes:
                    cfgs_network.append(cfg)
                    network_hashes.remove(network_hash)

            # Generate the networks
            print("Generating networks. Please wait")
            p_umap(f_single_network, cfgs_network, num_cpus=num_cores)

            # Then run the simulations on the network
            print("Running simulations. Please wait")
            f_single_simulation = partial(simulation.run_single_simulation,
                                          verbose=False)
            for cfg in p_uimap(f_single_simulation, cfgs, num_cpus=num_cores):
                simulation.update_database(db_cfg, q, cfg)

print(
    f"\n{N_files:,} files were generated, total duration {utils.format_time(t.elapsed)}"
)
print("Finished simulating!")
Code example #5
    mini_str = '_mini' if args.mini else ''
    eval_str = '_eval' if args.eval_mode else ''
    types = ['validation'] if args.eval_mode else ['validation', 'train']
    for type in types:
        print('Getting records for {} set'.format(type))
        type_df = get_records(split=type, mini=args.mini)
        if not args.eval_mode and len(type_df) > MAX_SUMMARIES:
            print('Shrinking from {} to {}'.format(len(type_df), MAX_SUMMARIES))
            type_df = type_df.sample(n=MAX_SUMMARIES, replace=False)
        type_examples = type_df.to_dict('records')
        n = len(type_examples)
        print('Processing {} examples for {} set'.format(n, type))
        if args.single_proc:
            x = list(tqdm(map(generate_samples, type_examples), total=n))
        else:
            x = list(p_uimap(generate_samples, type_examples, num_cpus=0.8))

        single_extraction_examples = [a[0] for a in x]
        rouge_diffs = [a[1] for a in x]
        rouge_gains = [a[2] for a in x]
        rouge_fulls = [a[3] for a in x]
        output = list(itertools.chain(*single_extraction_examples))
        out_n = len(output)
        account_n = len(set([x['account'] for x in output]))
        out_fn = os.path.join(out_dir, 'single_extraction_labels_{}{}{}.json'.format(type, eval_str, mini_str))
        print('Saving {} labeled single step extraction samples for {} visits to {}'.format(out_n, account_n, out_fn))
        with open(out_fn, 'w') as fd:
            json.dump(output, fd)

        all_rouge_diffs = defaultdict(list)
        all_rouge_gains = defaultdict(list)
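
Note the num_cpus=0.8 argument, which most of the remaining examples also use: p_tqdm's num_cpus accepts either an int (an exact worker count) or a float (a proportion of the available CPUs), so 0.8 requests roughly 80% of the machine's cores. A minimal sketch:

from p_tqdm import p_uimap

def work(x):
    return x + 1

if __name__ == '__main__':
    # num_cpus=0.8 -> about 80% of available cores; num_cpus=4 -> exactly 4.
    results = list(p_uimap(work, range(100), num_cpus=0.8))
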
Code example #6
        mel_path = (cm.mel_dir / file_name).with_suffix('.npy')
        np.save(mel_path, mel)
        return (file_name, mel.shape[0])

    print(
        f"Creating mels from all wavs found in {metadatareader.data_directory}"
    )
    print(f"\nMels will be stored under")
    print(f"{cm.mel_dir}")
    (cm.mel_dir).mkdir(exist_ok=True)
    audio = Audio(config=cm.config)
    wav_files = [metadatareader.wav_paths[k] for k in metadatareader.wav_paths]
    len_dict = {}
    remove_files = []
    mel_lens = []
    wav_iter = p_uimap(process_wav, wav_files)
    for (fname, mel_len) in wav_iter:
        len_dict.update({fname: mel_len})
        if mel_len > cm.config['max_mel_len'] or mel_len < cm.config[
                'min_mel_len']:
            remove_files.append(fname)
        else:
            mel_lens.append(mel_len)

    pickle.dump(len_dict, open(cm.data_dir / 'mel_len.pkl', 'wb'))
    pickle.dump(remove_files,
                open(cm.data_dir / 'under-over_sized_mels.pkl', 'wb'))
    summary_manager.add_histogram('Mel Lengths', values=np.array(mel_lens))
    total_mel_len = np.sum(mel_lens)
    total_wav_len = total_mel_len * audio.config['hop_length']
    summary_manager.display_scalar('Total duration (hours)',
Code example #7
File: main.py Project: griff4692/clin-sum
    srl_packed_source = replace_paragraphs(example['source_str'], full_packed_source.split(' <p> '))
    srl_packed_target = replace_paragraphs(course_str, full_packed_target.split(' <p> '))

    new_example['srl_packed_source'] = srl_packed_source
    new_example['srl_packed_target'] = srl_packed_target
    new_example['long_fragments'] = FRAG_DELIM.join(long_frags)

    return new_example


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Script to pack source and summary with entity and copy-paste fragments.')
    parser.add_argument('-mini', default=False, action='store_true')
    parser.add_argument('--max_n', default=None, type=int)

    args = parser.parse_args()

    splits = ['validation', 'train']
    mini_str = '_small' if args.mini else ''
    out_fn = os.path.join(out_dir, 'srl_packed_examples{}.csv'.format(mini_str))
    examples = get_records(split=splits, mini=args.mini).to_dict('records')
    if args.max_n is not None:
        examples = np.random.choice(examples, size=args.max_n, replace=False)
        out_fn = os.path.join(out_dir, 'srl_packed_examples_{}.csv'.format(str(args.max_n)))
    n = len(examples)
    examples_packed = list(p_uimap(pack_example, examples, num_cpus=0.8))
    print('Done! Now saving {} packed examples to {}'.format(len(examples_packed), out_fn))
    packed_examples_df = pd.DataFrame(examples_packed)
    packed_examples_df.to_csv(out_fn, index=False)
Code example #8
File: run_oracle.py Project: griff4692/clin-sum
        summarizer = greedy_rel_rouge
    elif args.strategy == 'greedy_rel_recall':
        summarizer = greedy_rel_rouge_recall
    elif args.strategy == 'top_k':
        summarizer = top_k_rouge
    elif args.strategy == 'top_k_recall':
        summarizer = top_k_rouge_recall
    elif args.strategy == 'random_recall':
        summarizer = random_recall
    elif args.strategy == 'random':
        summarizer = random

    if args.max_n > 0:
        np.random.seed(1992)
        records = np.random.choice(records, size=args.max_n, replace=False)

    outputs = list(filter(None, p_uimap(gen_summaries, records, num_cpus=0.8)))
    n = len(outputs)
    exp_str = 'oracle_{}'.format(args.strategy)
    alias_str = args.custom_path_alias or 'validation'
    if 'recall' in args.strategy:
        exp_str += '_{}'.format(args.recall_target_n)
    out_fn = os.path.join(out_dir, 'predictions',
                          '{}_{}.csv'.format(exp_str, alias_str))
    print('Saving {} predictions to {}'.format(n, out_fn))
    print(
        'To evaluate, run: cd ../evaluations && python rouge.py --experiment {}'
        .format(exp_str))
    output_df = pd.DataFrame(outputs)
    output_df.to_csv(out_fn, index=False)
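
Note the filter(None, ...) wrapper here (example #19 uses it too): gen_summaries may return None for records it cannot summarize, and the filter drops those before the DataFrame is built. A minimal sketch of this skip-by-returning-None pattern, with a made-up worker and skip condition:

from p_tqdm import p_uimap

def maybe_process(x):
    # Hypothetical skip condition: drop multiples of three.
    return x * 10 if x % 3 else None

if __name__ == '__main__':
    kept = list(filter(None, p_uimap(maybe_process, range(12))))
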
Code example #9
    'spacy_target_toks', 'spacy_source_tok_ct', 'spacy_target_tok_ct',
    'coverage', 'density', 'compression', 'fragments', 'is_too_big'
]


def collect_examples(mrn):
    mrn_dir = os.path.join(out_dir, 'mrn', str(mrn))
    df = pd.read_csv(os.path.join(mrn_dir, 'examples.csv'))
    assert len(df) > 0
    return df[COLS]


if __name__ == '__main__':
    _, _, mrns = get_mrn_status_df('valid_example')
    fn = os.path.join(out_dir, 'full_examples.csv')
    df = pd.concat(list(p_uimap(collect_examples, mrns, num_cpus=0.8)))
    df.to_csv(fn, index=False)

    small_mrns = set(np.random.choice(mrns, size=100, replace=False))
    tiny_mrns = set(np.random.choice(list(small_mrns), size=10, replace=False))

    small_fn = os.path.join(out_dir, 'full_examples_small.csv')
    small_df = df[df['mrn'].isin(small_mrns)]
    print('Saving {} examples to {}'.format(len(small_df), small_fn))
    small_df.to_csv(small_fn, index=False)

    tiny_fn = os.path.join(out_dir, 'full_examples_tiny.csv')
    tiny_df = small_df[small_df['mrn'].isin(tiny_mrns)]
    print('Saving {} examples to {}'.format(len(tiny_df), tiny_fn))
    tiny_df.to_csv(tiny_fn, index=False)
Code example #10
File: vocab.py Project: griff4692/clin-sum
        vocab = Vocab()
        vocab.add_tokens([x[0] for x in MATCHES])

        for i in tqdm(range(len(prev_vocab))):
            tok = prev_vocab.i2w[i]
            sup = prev_vocab.support[i]
            if sup >= args.min_tf:
                tok_adj = cast_num(tok)
                vocab.add_token(tok_adj, sup)

        out_fn = os.path.join('data', 'vocab_num_template.pk')
        print('Vocab reduced from {} to {}'.format(len(prev_vocab),
                                                   len(vocab)))
        print('Saving it now to {}'.format(out_fn))
        with open(out_fn, 'wb') as fd:
            pickle.dump(vocab, fd)
    else:
        records = get_records(split=['train', 'validation']).to_dict('records')
        tokens = p_uimap(get_tokens, records, num_cpus=0.8)
        tokens_flat = list(itertools.chain(*tokens))
        tok_cts = Counter(tokens_flat)
        vocab = Vocab()
        for t, v in tok_cts.items():
            if v >= args.min_tf or np.char.isnumeric(t):
                vocab.add_token(t, token_support=v)
        out_fn = os.path.join('data', 'vocab.pk')
        with open(out_fn, 'wb') as fd:
            pickle.dump(vocab, fd)
        print('Done! Saved vocabulary of size={} to {}'.format(
            len(vocab), out_fn))
Code example #11
File: get_entities.py Project: griff4692/clin-sum
        cat = CAT(cdb=cdb, vocab=vocab)

        print('Loading Spacy...')
        sentencizer = spacy.load(
            'en_core_sci_lg', disable=['tagger', 'parser', 'ner', 'textcat'])
        sentencizer.add_pipe(sentencizer.create_pipe('sentencizer'))
        print('Loading UMLS entity linker...')
        linker = EntityLinker(resolve_abbreviations=True, name='umls')
        cui_to_ent_map = linker.kb.cui_to_entity
        print('Let\'s go get some entities...')

        splits = ['validation', 'train']
        examples = get_records(split=splits, mini=args.mini).to_dict('records')

        num_ents = np.array(
            list(p_uimap(extract_entities, examples, num_cpus=0.8)))
        print('An average of {} entities extracted per visit'.format(
            num_ents.mean()))

    if args.collect:
        ent_fns = list(
            map(lambda x: os.path.join(entities_dir, x),
                os.listdir(entities_dir)))
        print('Collecting {} different entity files.'.format(len(ent_fns)))
        ent_df_arr = []
        for i in tqdm(range(len(ent_fns))):
            try:
                ent_df_arr.append(pd.read_csv(ent_fns[i]))
            except pd.errors.EmptyDataError:
                print('No entities for {}'.format(ent_fns[i]))
        ent_df = pd.concat(ent_df_arr)
Code example #12
        'mrn': mrn,
        'account': account,
        'split': record['split'],
        'hiv': record['hiv'],
        'num_target_sents': num_target_sents,
        'cui_info': cui_info
    }


if __name__ == '__main__':
    print('Loading Spacy...')
    sentencizer = spacy.load('en_core_sci_lg',
                             disable=['tagger', 'parser', 'ner', 'textcat'])
    sentencizer.add_pipe(sentencizer.create_pipe('sentencizer'))

    print('Getting records')
    splits = ['validation', 'train']
    examples = get_records(split=splits, mini=False).to_dict('records')
    n = len(examples)

    egrids = list(p_uimap(get_grid, examples, num_cpus=0.8))
    out_fn = os.path.join(out_dir, 'egrids.json')
    print('Done! Now saving {} e-grids to {}'.format(len(egrids), out_fn))
    with open(out_fn, 'w') as fd:
        json.dump(egrids, fd)

    egrids_small = list(np.random.choice(egrids, size=1000, replace=False))
    out_fn = os.path.join(out_dir, 'egrids_small.json')
    with open(out_fn, 'w') as fd:
        json.dump(egrids_small, fd)
Code example #13
    examples_df['is_too_big'] = is_too_big
    # Generally, this means we have a repeated hospital course for some reason
    examples_df.drop_duplicates(subset=['spacy_target_tok_ct'], inplace=True)
    examples_df.to_csv(examples_fn, index=False)
    return len(examples_df), source_sent_lens, target_sent_lens, too_big_ct


if __name__ == '__main__':
    print('Loading scispacy')
    spacy_nlp = spacy.load('en_core_sci_lg',
                           disable=['tagger', 'parser', 'ner', 'textcat'])
    spacy_nlp.add_pipe(spacy_nlp.create_pipe('sentencizer'))
    print('Ready to tokenize!')

    _, _, mrns = get_mrn_status_df('valid_example')
    n = len(mrns)
    print('Processing {} mrns'.format(n))
    outputs = list(p_uimap(tokenize_mrn, mrns, num_cpus=0.8))
    examples = [x[0] for x in outputs]
    source_sent_lens = np.array(list(itertools.chain(*[x[1]
                                                       for x in outputs])))
    target_sent_lens = np.array(list(itertools.chain(*[x[2]
                                                       for x in outputs])))
    num_examples = sum(examples)
    too_big_ct = sum([x[3] for x in outputs])
    print('Tokenized {} examples. {} were flagged as too big'.format(
        num_examples, too_big_ct))

    print('Average source sentence length: {}'.format(source_sent_lens.mean()))
    print('Average target sentence length: {}'.format(target_sent_lens.mean()))
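
Examples #5 and #13 both materialize the unordered iterator with list(...) and then split the per-worker tuples into parallel collections with indexed comprehensions. zip(*outputs) is an equivalent, more compact unpacking; a minimal sketch with a made-up worker:

from p_tqdm import p_uimap

def analyze(x):
    # Hypothetical worker returning a fixed-size tuple per input.
    return x, x * x, x % 2 == 0

if __name__ == '__main__':
    outputs = list(p_uimap(analyze, range(10)))
    values, squares, parities = zip(*outputs)
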
Code example #14
        for k, v in frag_obj.items():
            frag_dicts[k].append(v)

    for k, v in frag_dicts.items():
        examples_df[k] = v
    examples_df.to_csv(examples_fn, index=False)

    return frag_dicts


if __name__ == '__main__':
    _, _, mrns = get_mrn_status_df('valid_example')
    n = len(mrns)
    print('Processing {} mrns'.format(n))
    start_time = time()
    outputs = list(p_uimap(get_extractive_fragments, mrns, num_cpus=0.8))
    duration(start_time)
    stat_names = ['compression', 'coverage', 'density']
    stats = defaultdict(list)
    for output in outputs:
        for stat in stat_names:
            stats[stat] += output[stat]

    df = pd.DataFrame(stats)
    out_fn = '../evaluations/results/extractiveness.csv'
    df.to_csv(out_fn, index=False)

    for stat, vals in stats.items():
        print('Statistic={}...'.format(stat))
        print('\t', describe(vals))
Code example #15
File: simulation.py Project: ESMartiny/NetworkSIR
def run_simulations(
        simulation_parameters,
        N_runs=2,
        num_cores_max=None,
        N_tot_max=False,
        verbose=False,
        force_rerun=False,
        dry_run=False,
        **kwargs) :

    if isinstance(simulation_parameters, dict) :
        simulation_parameters = utils.format_simulation_paramters(simulation_parameters)
        cfgs_all = utils.generate_cfgs(simulation_parameters, N_runs, N_tot_max, verbose=verbose)

        N_tot_max = utils.d_num_cores_N_tot[utils.extract_N_tot_max(simulation_parameters)]

    elif isinstance(simulation_parameters[0], utils.DotDict) :
        cfgs_all = simulation_parameters

        N_tot_max = np.max([cfg.network.N_tot for cfg in cfgs_all])

    else :
        raise ValueError(f"simulation_parameters not of the correct type")

    if len(cfgs_all) == 0 :
        N_files = 0
        return N_files

    db_cfg = utils.get_db_cfg()
    q = Query()

    db_counts  = np.array([db_cfg.count((q.hash == cfg.hash) & (q.network.ID == cfg.network.ID)) for cfg in cfgs_all])

    assert np.max(db_counts) <= 1

    # keep only cfgs that are not in the database already
    if force_rerun :
        cfgs = cfgs_all
    else :
        cfgs = [cfg for (cfg, count) in zip(cfgs_all, db_counts) if count == 0]

    N_files = len(cfgs)

    num_cores = utils.get_num_cores_N_tot(N_tot_max, num_cores_max)

    if isinstance(simulation_parameters, dict) :
        s_simulation_parameters = str(simulation_parameters)
    elif isinstance(simulation_parameters, list) :
        s_simulation_parameters = f"{len(simulation_parameters)} runs"
    else :
        raise AssertionError("simulation_parameters neither list nor dict")

    print( f"\n\n" f"Generating {N_files :3d} network-based simulations",
           f"with {num_cores} cores",
           f"based on {s_simulation_parameters}.",
           "Please wait. \n",
           flush=True)

    if dry_run or N_files == 0 :
        return N_files

    # kwargs = {}
    if num_cores == 1 :
        for cfg in tqdm(cfgs) :
            cfg_out = run_single_simulation(cfg, save_initial_network=True, verbose=verbose, **kwargs)
            update_database(db_cfg, q, cfg_out)

    else :
        # First generate the networks
        f_single_network = partial(run_single_simulation, only_initialize_network=True, save_initial_network=True, verbose=verbose, **kwargs)

        # Get the network hashes
        network_hashes = set([utils.cfg_to_hash(cfg.network, exclude_ID=False) for cfg in cfgs])

        # Get list of unique cfgs
        cfgs_network = []
        for cfg in cfgs :
            network_hash = utils.cfg_to_hash(cfg.network, exclude_ID=False)

            if network_hash in network_hashes :
                cfgs_network.append(cfg)
                network_hashes.remove(network_hash)

        # Generate the networks
        print("Generating networks. Please wait")
        p_umap(f_single_network, cfgs_network, num_cpus=num_cores)

        # Then run the simulations on the network
        print("Running simulations. Please wait")
        f_single_simulation = partial(run_single_simulation, verbose=verbose, **kwargs)
        for cfg in p_uimap(f_single_simulation, cfgs, num_cpus=num_cores) :
            update_database(db_cfg, q, cfg)

    return N_files
Code example #16
File: compute_lens.py Project: griff4692/clin-sum
        'source_sent_lens': source_sent_lens,
        'target_sent_lens': target_sent_lens
    })


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Script to compute dataset statistics.')

    args = parser.parse_args()

    in_fn = os.path.join(out_dir, 'full_examples_no_trunc.csv')
    print('Loading data from {}'.format(in_fn))
    df = pd.read_csv(in_fn)
    print('Loaded {} distinct visits'.format(len(df)))
    outputs = list(
        p_uimap(generate_counts, df.to_dict('records'), num_cpus=0.8))

    counts = [output[0] for output in outputs]
    source_sent_lens = np.array(
        list(
            itertools.chain(
                *[output[1]['source_sent_lens'] for output in outputs])))
    target_sent_lens = np.array(
        list(
            itertools.chain(
                *[output[1]['target_sent_lens'] for output in outputs])))

    print('Source sentence length. Mean={}. STD={}.'.format(
        np.mean(source_sent_lens), np.std(source_sent_lens)))
    print('Target sentence length. Mean={}. STD={}.'.format(
        np.mean(target_sent_lens), np.std(target_sent_lens)))
Code example #17
File: linearize_kg.py Project: sshaar/aqa
if __name__ == '__main__':
    parser = argparse.ArgumentParser('PyTorch Dataset wrapper for Question Generation Task.')
    parser.add_argument('--dataset', default='squad', help='trivia_qa or hotpot_qa')
    parser.add_argument(
        '-debug', default=False, action='store_true', help='If true, run on tiny portion of train dataset')
    args = parser.parse_args()

    dataset = dataset_factory(args.dataset)
    if dataset.name == 'squad':
        dtypes = ['mini'] if args.debug else ['train', 'validation']
    else:
        dtypes = ['mini'] if args.debug else ['train', 'test', 'validation']

    for dtype in dtypes:
        print('Linearizing knowledge graphs for {} set'.format(dtype))
        d = dataset[dtype]
        kg_fn = os.path.join('..', 'data', dataset.name, 'kg_{}.pk'.format(dtype))
        print('Loading knowledge graphs for {} set...'.format(dtype))
        with open(kg_fn, 'rb') as fd:
            kgs = pickle.load(fd)

        n = len(d)
        examples_w_graph = [(example, kgs[example['id']]) for example in d if example['id'] in kgs]
        outputs = list(p_uimap(linearize_graph, examples_w_graph))

        out_fn = os.path.join('..', 'data', dataset.name, 'dataset_{}.csv'.format(dtype))
        df = pd.DataFrame(outputs)
        print('Saving {} examples to {}'.format(df.shape[0], out_fn))
        df.to_csv(out_fn, index=False)
Code example #18
File: run_lexrank.py Project: griff4692/clin-sum
        train_docs = train_df['spacy_target_toks'].apply(aggregate)
        lxr = LexRank(train_docs, stopwords=stopwords)

        idf_obj = {
            'idf_score': dict(lxr.idf_score),
            'default': lxr.default
        }

        print('Saving IDF so we don\'t need to recompute it...')
        with open(idf_fn, 'w') as fd:
            json.dump(idf_obj, fd)

    validation_records = get_records(split='validation', mini=False).to_dict('records')

    if args.compute_stats:
        outputs = list(itertools.chain(*list(p_uimap(compute_lr_stats, validation_records, num_cpus=0.8))))
        lr_scores = [o[0] for o in outputs]
        r_scores = [o[1] for o in outputs]
        corel = pearsonr(lr_scores, r_scores)
        print('Pearson correlation of LR and R12={}. p-value={}'.format(corel[0], corel[1]))
    else:
        outputs = list(p_uimap(compute_lr, validation_records, num_cpus=0.8))

        exp_str = 'lr'
        out_fn = os.path.join(out_dir, 'predictions', '{}_validation.csv'.format(exp_str))

        output = [o[0] for o in outputs]
        output_df = pd.DataFrame(output)
        n = len(output_df)
        print('Saving {} predictions to {}'.format(n, out_fn))
        output_df.to_csv(out_fn, index=False)
Code example #19
File: run_retrieval.py Project: griff4692/clin-sum
    validation_df = get_records(split='validation', mini=mini)
    validation_records = validation_df.to_dict('records')

    if args.max_n > 0:
        np.random.seed(1992)
        validation_records = np.random.choice(validation_records,
                                              size=args.max_n,
                                              replace=False)

    print('Loading BM25...')
    bm25_fn = os.path.join(out_dir, 'bm25.pk')
    with open(bm25_fn, 'rb') as fd:
        bm25 = pickle.load(fd)

    print(
        'Loading original corpus (train sentences) that BM25 has indexed...'
    )
    train_sents_fn = os.path.join(out_dir, 'train_sents.csv')
    corpus = pd.read_csv(train_sents_fn).sents.tolist()

    print('Let\'s retrieve!')
    outputs = list(
        filter(None, p_uimap(gen_summaries, validation_records, num_cpus=0.8)))
    out_fn = os.path.join(out_dir, 'predictions', 'retrieval_validation.csv')
    output_df = pd.DataFrame(outputs)
    output_df.to_csv(out_fn, index=False)

    print(
        'To evaluate, run: cd ../evaluations && python rouge.py --experiment retrieval'
    )