-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
166 lines (152 loc) · 5.44 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import argparse
import yaml
from processor.processor import Processor, init_seed
def get_parser():
    """Build the command-line parser for training / testing.

    Parameter priority is: command line > config file > parser default.
    Config-file values can be layered in by the caller with
    ``parser.set_defaults(**config)`` followed by a second ``parse_args()``.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """

    def str2bool(v):
        """Parse a yes/no style token (case-insensitive) into a bool.

        Raises:
            argparse.ArgumentTypeError: if ``v`` is not a recognised token.
        """
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Boolean value expected.')

    # parameter priority: command line > config > default
    parser = argparse.ArgumentParser(
        description='Spatial Temporal Graph Convolution Network')
    parser.add_argument(
        '--work-dir',
        default='./work_dir/temp',
        help='the work folder for storing results')
    parser.add_argument('-model_saved_name', default='')
    parser.add_argument(
        '--config',
        default='./config/nturgbd-cross-view/test_bone.yaml',
        help='path to the configuration file')

    # processor
    parser.add_argument(
        '--phase', default='train', help='must be train or test')
    parser.add_argument(
        '--save-score',
        type=str2bool,
        default=False,
        help='if true, the classification score will be stored')

    # visualize and debug
    parser.add_argument(
        '--seed', type=int, default=1, help='random seed for pytorch')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        help='the interval for printing messages (#iteration)')
    parser.add_argument(
        '--save-interval',
        type=int,
        default=1,
        help='the interval for storing models (#iteration)')
    parser.add_argument(
        '--eval-interval',
        type=int,
        default=5,
        help='the interval for evaluating models (#iteration)')
    parser.add_argument(
        '--print-log',
        type=str2bool,
        default=True,
        help='print logging or not')

    # feeder
    parser.add_argument(
        '--feeder', default='feeder.feeder', help='data loader will be used')
    parser.add_argument(
        '--num-worker',
        type=int,
        default=32,
        help='the number of worker for data loader')
    parser.add_argument(
        '--train-feeder-args',
        default=dict(),
        help='the arguments of data loader for training')
    parser.add_argument(
        '--test-feeder-args',
        default=dict(),
        help='the arguments of data loader for test')

    # model
    parser.add_argument('--model', default=None, help='the model will be used')
    # NOTE(review): type=dict cannot parse a command-line string, so in
    # practice this is only settable through the config file.
    parser.add_argument(
        '--model-args',
        type=dict,
        default=dict(),
        help='the arguments of model')
    parser.add_argument(
        '--weights',
        default=None,
        help='the weights for network initialization')
    parser.add_argument(
        '--ignore-weights',
        type=str,
        default=[],
        nargs='+',
        help='the name of weights which will be ignored in the initialization')

    # optim
    parser.add_argument(
        '--base-lr', type=float, default=0.01, help='initial learning rate')
    parser.add_argument(
        '--step',
        type=int,
        default=[20, 40, 60],
        nargs='+',
        help='the epoch where optimizer reduce the learning rate')
    # NOTE(review): the default is a scalar int while the command line yields
    # a list (nargs='+'); presumably the consumer handles both — confirm.
    parser.add_argument(
        '--device',
        type=int,
        default=0,
        nargs='+',
        help='the indexes of GPUs for training or testing')
    parser.add_argument('--optimizer', default='Adam', help='type of optimizer')
    parser.add_argument(
        '--nesterov', type=str2bool, default=False, help='use nesterov or not')
    parser.add_argument(
        '--batch-size', type=int, default=256, help='training batch size')
    parser.add_argument(
        '--test-batch-size', type=int, default=256, help='test batch size')
    parser.add_argument(
        '--start-epoch',
        type=int,
        default=0,
        help='start training from which epoch')
    parser.add_argument(
        '--num-epoch',
        type=int,
        default=80,
        help='stop training in which epoch')
    parser.add_argument(
        '--weight-decay',
        type=float,
        default=0.0005,
        help='weight decay for optimizer')
    # Typed so command-line values are real bools/ints rather than strings
    # (a bare 'False' string would otherwise be truthy).
    parser.add_argument('--only_train_part', type=str2bool, default=False)
    parser.add_argument('--only_train_epoch', type=int, default=0)
    parser.add_argument('--warm_up_epoch', type=int, default=0)

    # data generation
    parser.add_argument('--num_of_weeks', type=int, default=0, help='The previous week')
    parser.add_argument('--num_of_days', type=int, default=0, help='The previous day')
    parser.add_argument('--num_of_hours', type=int, default=1, help='The previous hour')
    parser.add_argument('--num_for_predict', type=int, default=12, help='Prediction interval')
    parser.add_argument('--points_per_hour', type=int, default=12, help='The number of point per hour')
    parser.add_argument('--num_of_vertices', type=int, default=358, help='The number of vertices')
    parser.add_argument('--gen_config_args', type=dict, default=dict(), help='The config of data generate')
    return parser
if __name__ == '__main__':
    parser = get_parser()

    # First pass: only needed to discover --config and the set of known keys.
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, 'r', encoding='utf-8') as f:
            # safe_load: yaml.load() without an explicit Loader is rejected by
            # PyYAML >= 6 and can construct arbitrary Python objects.
            # An empty file loads as None, so fall back to no overrides.
            default_arg = yaml.safe_load(f) or {}
        known_keys = vars(p)
        for k in default_arg:
            if k not in known_keys:
                # Explicit error rather than assert (asserts vanish under -O).
                parser.error('WRONG ARG: {}'.format(k))
        # Config values become the new defaults; the second parse below still
        # lets explicit command-line flags override them.
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    # Seed from --seed so the flag defined in get_parser() takes effect.
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()