-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualize_model_seg.py
72 lines (55 loc) · 2.4 KB
/
visualize_model_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import glob
import os
from argparse import ArgumentParser

import gensim
import numpy as np
import torch
from pathlib2 import Path

import choiloader
import evaluate
import utils
import wiki_loader
from choiloader import clean_paragraph
# Marker line written (surrounded by newlines) after each segment that ends at
# a predicted cutoff in the output files.
section_delimiter = "-----"
def segment(path, model, word2vec, output_folder, wiki = False):
    """Segment every ``*.txt`` file under ``path`` and write the results.

    For each matching file the text is cleaned with ``clean_paragraph``,
    cut-off predictions are requested from ``evaluate.predict_cutoffs`` and
    the resulting segments are written to a file with the same base name
    inside ``output_folder``. Segments that end at a predicted cutoff are
    followed by a ``section_delimiter`` line.

    Args:
        path: Directory holding the ``.txt`` inputs; a trailing separator is
            optional. If ``path`` is not an existing directory it is used as
            a plain filename prefix (``path + '*.txt'``), as before.
        model: Passed unchanged to ``evaluate.predict_cutoffs``.
        word2vec: Passed unchanged to ``evaluate.predict_cutoffs``.
        output_folder: Directory the output files are written into. It is
            not created here.
        wiki: Accepted for interface compatibility; not used by this
            function.
    """
    # Build the glob pattern with os.path.join so that a directory passed
    # without a trailing separator still matches the files *inside* it rather
    # than siblings that merely share its name as a prefix.
    if os.path.isdir(path):
        pattern = os.path.join(path, '*.txt')
    else:
        pattern = path + '*.txt'

    for filename in glob.glob(pattern):
        # Read-only access is enough; close the input before predicting.
        with open(filename, 'r') as source_file:
            paragraph = source_file.read()

        # NOTE(review): the cleaned text is wrapped in a one-element list, so
        # the whole file is handed to the model as a single "sentence" --
        # confirm against clean_paragraph / predict_cutoffs that this is
        # intended.
        sentences = [clean_paragraph(paragraph)]
        cutoffs = evaluate.predict_cutoffs(sentences, model, word2vec)

        total = []
        current_segment = []
        for sentence, cutoff in zip(sentences, cutoffs):
            current_segment.append(sentence)
            if cutoff:
                full_segment = '.'.join(current_segment) + '.'
                total.append(full_segment + '\n' + section_delimiter + '\n')
                current_segment = []

        # Model does not return prediction for last sentence
        current_segment.append(sentences[-1])
        total.append('.'.join(current_segment))

        # basename() handles the platform's path separator, so the output
        # always lands inside output_folder under the input's file name.
        file_id = os.path.basename(filename)
        output_file_full_path = Path(output_folder).joinpath(file_id)
        with output_file_full_path.open('w') as output_file:
            output_file.write("".join(total))
def main(args):
    """Load configuration, embeddings and model, then run ``segment``.

    ``args`` is the parsed command-line namespace. Its attributes are merged
    into ``utils.config`` after the config file named by ``args.config`` has
    been read. When ``args.test`` is set no word2vec file is loaded and
    ``None`` is passed on instead.
    """
    utils.read_config_file(args.config)
    utils.config.update(vars(args))

    # Test mode skips loading the word2vec file.
    word2vec = None
    if not args.test:
        word2vec = gensim.models.KeyedVectors.load_word2vec_format(
            utils.config['word2vecfile'],
            binary=True,
        )

    # NOTE: torch.load unpickles the file; only point --model at trusted files.
    with open(args.model, 'rb') as model_file:
        model = torch.load(model_file)
    model.eval()

    segment(args.path, model, word2vec, args.output, wiki=args.wiki)
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    arg_parser = ArgumentParser()

    # Boolean switches.
    arg_parser.add_argument(
        '--test', action='store_true',
        help='Test mode? (e.g fake word2vec)')

    # Model and configuration locations.
    arg_parser.add_argument(
        '--model', required=True,
        help='Model to run - will import and run')
    arg_parser.add_argument(
        '--config', default='./config.json',
        help='Path to config.json')

    # Input and output locations.
    arg_parser.add_argument(
        '--path', default='./data/Dataset/test-data/',
        help='Path to files to segment by model')
    arg_parser.add_argument(
        '--output', required=True,
        help='output folder')

    arg_parser.add_argument(
        '--wiki', action='store_true',
        help='use wikipedia files')

    cli_args = arg_parser.parse_args()
    main(cli_args)