-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_train_net.py
75 lines (56 loc) · 2.07 KB
/
my_train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Modules used
import functools
from multiprocessing import Queue
# From Library by authors
from lib.solver import Solver
from lib.data_io import category_model_id_pair
from lib.data_process import kill_processes, make_data_processes
# My reimplementation
from my_3DR2N2.my_res_gru_net import My_ResidualGRUNet
# Module-level handles to the data queues and loader processes.
# train_net() rebinds them; cleanup_handle reads them so it can shut the
# loaders down if training is interrupted or fails.
train_queue = validation_queue = train_processes = val_processes = None
# Clean up in case of unexpected quit
def cleanup_handle(func):
    """Decorate *func* so the data-loader processes are shut down if it raises.

    Any exception escaping *func* -- including ``KeyboardInterrupt`` from a
    forced quit -- triggers ``kill_processes`` on the module-level training
    and validation queues/processes, after which the original exception is
    re-raised unchanged.  Handles that are still ``None`` (the failure
    happened before they were created) are skipped, so a setup error is not
    masked by an error coming out of the cleanup itself.

    Args:
        func: The callable to wrap (in this script, ``train_net``).

    Returns:
        A wrapper carrying *func*'s name and docstring that forwards all
        arguments to *func* and returns its result.
    """
    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except BaseException:
            # BaseException (not Exception) on purpose: a Ctrl-C / forced
            # quit must also reach the process cleanup below.
            print('Wait until the dataprocesses to end')
            # Only tear down what was actually created; the handles are
            # rebound from None by train_net as setup progresses.
            if train_processes is not None:
                kill_processes(train_queue, train_processes)
            if val_processes is not None:
                kill_processes(validation_queue, val_processes)
            raise
    return func_wrapper
# Train the network
@cleanup_handle
def train_net():
    """Build the residual GRU network and train it on process-loaded data.

    Two ``multiprocessing.Queue`` objects (capacity 15 minibatches each) are
    fed by loader processes from ``make_data_processes``.  The training loader
    uses the first 80 percent of the data and the validation loader uses the
    remaining 20 percent with ``train=False``.  ``Solver.train`` then consumes
    both queues.

    The queues and processes are stored in module globals so that
    ``cleanup_handle`` can shut them down if anything raises.  After a normal
    return from ``Solver.train`` they are shut down here with
    ``kill_processes``.
    """
    # Rebind the module-level handles instead of creating locals.
    global train_queue, validation_queue, train_processes, val_processes

    network = My_ResidualGRUNet()
    solver = Solver(network)

    queue_capacity = 15  # max number of minibatches buffered in each queue
    train_queue = Queue(queue_capacity)
    validation_queue = Queue(queue_capacity)

    workers_per_split = 1
    train_split = [0, 0.8]  # first 80 percent of the data
    val_split = [0.8, 1]    # last 20 percent of the data

    train_processes = make_data_processes(
        train_queue,
        category_model_id_pair(dataset_portion=train_split),
        workers_per_split,
        repeat=True,
    )
    val_processes = make_data_processes(
        validation_queue,
        category_model_id_pair(dataset_portion=val_split),
        workers_per_split,
        repeat=True,
        train=False,
    )

    solver.train(train_queue, validation_queue)

    # Normal completion: stop the loader processes for both splits.
    kill_processes(train_queue, train_processes)
    kill_processes(validation_queue, val_processes)
# Only start training when executed as a script.  With the spawn/forkserver
# start methods, multiprocessing re-imports the main module in each child
# process; without this guard every child would try to start training again.
if __name__ == "__main__":
    train_net()