-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
81 lines (67 loc) · 2.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import datetime
import argparse
from model import SentimentModel
from data import (
load_dataset, generate_train_data, generate_train_sequences,
generate_inference_sequence)
from config import DEFAULT_LOG_DIR, MAX_WORDS
NUM_TRAINING_SAMPLES = 500000
def create_model(log_dir=None, val_tuple=None, vocab=None):
    """Construct a SentimentModel wired with the shared vocab size.

    Args:
        log_dir: directory for training logs/checkpoints (None for inference).
        val_tuple: validation data handed to the TensorBoard callback config.
        vocab: vocabulary handed to the TensorBoard callback config.

    Returns:
        A SentimentModel instance built from the assembled config dict.
    """
    tensorboard_cfg = {
        'vocab': vocab,
        'val_tuple': val_tuple,
    }
    model_cfg = {
        'vocab_size': MAX_WORDS,
        'log_dir': log_dir,
        'tensorboard': tensorboard_cfg,
    }
    return SentimentModel(config=model_cfg)
def train(dataset, logs):
    """Train the sentiment model on `dataset`, logging under `logs`.

    Args:
        dataset: path to the dataset directory handed to load_dataset().
        logs: base directory under which a timestamped run dir is created.
    """
    # Bug fix: this previously read the module-global `args.logs` instead of
    # the `logs` parameter, so calling train() outside the CLI entry point
    # failed (or silently ignored the argument).
    log_dir = '{}/iter_{:%Y%m%dT%H%M%S}'.format(logs, datetime.datetime.now())
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    vocab, train_tuple, val_tuple = generate_train_data(
        load_dataset(dataset), log_dir, NUM_TRAINING_SAMPLES)
    training_seq, validation_seq = generate_train_sequences(
        train_tuple, val_tuple)
    model = create_model(log_dir, val_tuple, vocab)
    model.train(training_seq, validation_seq)
def inference(weights_path, tokenizer_path, text):
    """Run sentiment analysis on `text` using saved weights and tokenizer.

    Args:
        weights_path: path to the saved .h5 model weights.
        tokenizer_path: path to the saved tokenizer used to encode `text`.
        text: raw input string to analyse.
    """
    model = create_model()
    model.load_weights(weights_path)
    encoded = generate_inference_sequence(text, tokenizer_path)
    sentiment, category = model.analyze(encoded)
    message = 'Text sentiment is {}, and it\'s mostly about {}.'.format(
        sentiment, category)
    print(message)
if __name__ == '__main__':
    # CLI entry point: `main.py train --dataset ...` or
    # `main.py inference --weights ... --tokenizer ... --text ...`.
    parser = argparse.ArgumentParser(
        description='Attention-Based LSTM for sentiment analysis')
    parser.add_argument('command',
                        metavar='<command>',
                        help='`train` or `inference`')
    parser.add_argument('--logs', required=False,
                        default=DEFAULT_LOG_DIR,
                        metavar='/path/to/logs/',
                        help='Logs dir (default=/tmp/sentiment-analysis)')
    parser.add_argument('--dataset', required=False,
                        metavar='path to dataset',
                        help='Path to dataset directory')
    parser.add_argument('--weights', required=False,
                        metavar='path to saved weights',
                        help='Path to saved .h5 weights')
    parser.add_argument('--tokenizer', required=False,
                        metavar='path to saved tokenizer',
                        help='Path to saved tokenizer')
    parser.add_argument('--text', required=False,
                        metavar='text to analyse',
                        help='Text to analyse')
    args = parser.parse_args()
    # Validate with parser.error() instead of `assert`: asserts are stripped
    # under `python -O`, which would let missing arguments slip through.
    if args.command == 'train':
        if not args.dataset:
            parser.error('Argument --dataset is required for training')
        train(args.dataset, args.logs)
    elif args.command == 'inference':
        if not args.weights:
            parser.error('Argument --weights is required for inference')
        if not args.tokenizer:
            parser.error('Argument --tokenizer is required for inference')
        if not args.text:
            parser.error('Argument --text is required for inference')
        inference(args.weights, args.tokenizer, args.text)
    else:
        # Previously an unknown command fell through silently and did nothing.
        parser.error("Unknown command '{}'; use `train` or `inference`".format(
            args.command))