-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_exp.py
52 lines (42 loc) · 1.92 KB
/
run_exp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
import glob
import os
from pathlib import Path
import numpy as np
import torch
import yaml
from train import run
def _run_experiment(config_path, seed):
    """Create the output directories for one experiment config and train it.

    Args:
        config_path: Path to an ``exps/<name>/config.yml`` file.
        seed: Random seed forwarded to ``train.run`` as ``random_seed``.
    """
    # YAML is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        args = yaml.safe_load(f)
    # NOTE(review): the config is assumed to provide 'log_dir', 'name', 'dir'
    # and a nested 'config' mapping with 'output_dir' -- confirm against
    # train.run.
    # makedirs(exist_ok=True) replaces the exists()+mkdir() pair: it has no
    # check-then-create race and also creates missing parent directories.
    os.makedirs(os.path.join(args['log_dir'], args['name']), exist_ok=True)
    os.makedirs(os.path.join(args['dir'], args['config']['output_dir']), exist_ok=True)
    print(f'Working in experiment {args["name"]}')
    run(args, random_seed=seed)


def main(argv=None):
    """Run pending experiment configs found under ``exps/``.

    An experiment directory ``exps/<name>/`` counts as already run when it
    contains at least one ``*_testinglog.npy`` file. With ``--exp all`` (the
    default), every pending config is run in sorted path order. Otherwise
    ``--exp`` names a single ``exps/<name>/config.yml`` to run.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default
            ``None`` keeps the normal command-line behaviour.

    Raises:
        ValueError: ``--exp`` names a config that has already been run
            ('This is a done experiment'), or a path that is not an
            ``exps/*/config.yml`` file ('Unknown experiment').
    """
    parser = argparse.ArgumentParser(description='Fast weights Training')
    parser.add_argument('--exp', default='all', help='run a single experiment or all experiments')
    parser.add_argument('--seed', default=1234, type=int, help='random seed')
    opts = parser.parse_args(argv)

    # Normalise every path so spellings such as './exps/a/config.yml' match
    # the glob result 'exps/a/config.yml'.
    exp_files = sorted(os.path.normpath(f) for f in glob.glob('exps/*/config.yml'))
    exp_already_run = {
        os.path.normpath(Path(f).parent / 'config.yml')
        for f in glob.glob('exps/*/*_testinglog.npy')
    }
    # Filter the sorted list rather than taking a set difference. A set would
    # iterate in hash order, which changes between interpreter runs.
    exp_to_run = [f for f in exp_files if f not in exp_already_run]

    if opts.exp == 'all':
        for exp in exp_to_run:
            _run_experiment(exp, opts.seed)
        return

    exp = os.path.normpath(opts.exp)
    if exp not in exp_files:
        raise ValueError('Unknown experiment')
    if exp not in exp_to_run:
        raise ValueError('This is a done experiment')
    _run_experiment(exp, opts.seed)
if __name__ == '__main__':
    # Launch the experiment runner only when this file is executed directly,
    # not when it is imported as a module.
    main()