-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualize_model_seg.py
84 lines (60 loc) · 2.49 KB
/
visualize_model_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gensim
import evaluate
import utils
from pathlib2 import Path
from argparse import ArgumentParser
import torch
import choiloader
import numpy as np
# Marker written into the output where the gold (reference) segmentation
# places a boundary (see segment(): emitted when gold_set[i] == 1).
goldset_delimiter = "********"
# Marker written where the model predicts a section boundary
# (see segment(): emitted when the predicted cutoff fires).
section_delimiter = "========"
def segment(path, model, word2vec, output_folder):
    """Segment one choi-format file with the model and write an annotated copy.

    The output file (named like the input, placed in ``output_folder``)
    contains the document's sentences with ``section_delimiter`` inserted
    where the model predicts a boundary and ``goldset_delimiter`` where the
    gold (reference) segmentation places one.

    :param path: path to a choi-format input file.
    :param model: trained segmentation model (expected to be in eval mode).
    :param word2vec: embedding model forwarded to the loader and predictor
        (may be ``None``).
    :param output_folder: directory that receives the annotated file.
    """
    # Path.name is portable; the original '/'-split breaks on Windows paths.
    file_id = Path(path).name
    splited_sentences, target, _ = choiloader.read_choi_file(path, word2vec, False, False)
    sentences = [' '.join(s) for s in splited_sentences]
    # One int flag per sentence: 1 where the gold standard ends a section.
    gold_set = np.zeros(len(splited_sentences)).astype(int)
    gold_set[np.asarray(target)] = 1
    cutoffs = evaluate.predict_cutoffs(sentences, model, word2vec)
    total = []
    # Renamed from `segment`: the original local shadowed this function's name.
    current_segment = []
    for i, (sentence, cutoff) in enumerate(zip(sentences, cutoffs)):
        current_segment.append(sentence)
        if cutoff or gold_set[i] == 1:
            full_segment = '\n'.join(current_segment) + '.\n'
            if cutoff:
                full_segment = full_segment + '\n' + section_delimiter + '\n'
                # Boundary predicted AND present in the gold set.
                if gold_set[i] == 1:
                    full_segment = full_segment + goldset_delimiter + '\n'
            else:
                # Gold boundary the model missed (no predicted cutoff here).
                full_segment = full_segment + '\n' + goldset_delimiter + '\n'
            total.append(full_segment + '\n')
            current_segment = []
    # Model does not return a prediction for the last sentence, so flush
    # whatever remains as the final segment.
    current_segment.append(sentences[-1])
    total.append('.'.join(current_segment) + '\n')
    output_file_content = "".join(total)
    output_file_full_path = Path(output_folder).joinpath(Path(file_id))
    print(output_file_full_path)
    with output_file_full_path.open('w') as f:
        f.write(output_file_content)
def main(args):
    """Load a checkpoint and segment every file listed in ``args.file``.

    :param args: parsed CLI namespace with ``config``, ``model``, ``file``
        and ``output`` attributes (see the argument parser below).
    """
    utils.read_config_file(args.config)
    utils.config.update(args.__dict__)
    # One input-file path per line; blank lines are skipped in the loop below.
    with Path(args.file).open('r') as f:
        file_names = f.read().strip().split('\n')
    word2vec = None
    # NOTE(security): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    with open(args.model, 'rb') as f:
        # map_location='cpu' lets a GPU-trained checkpoint run on CPU-only hosts.
        model = torch.load(f, map_location='cpu')
    model.eval()
    for name in file_names:
        if name:
            segment(Path(name), model, word2vec, args.output)
if __name__ == '__main__':
    # Command-line entry point: register the flags and delegate to main().
    arg_parser = ArgumentParser()
    for flag, options in (
        ('--model', dict(help='Model to run - will import and run', required=True)),
        ('--config', dict(help='Path to config.json', default='config.json')),
        ('--file', dict(help='file containing file names to segment by model', required=True)),
        ('--output', dict(help='output folder', required=True)),
    ):
        arg_parser.add_argument(flag, **options)
    main(arg_parser.parse_args())