forked from DianboWork/SPN4RE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
100 lines (80 loc) · 4.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse, os, torch
import random
import numpy as np
from utils.data import build_data
from trainer.trainer import Trainer
from models.setpred4RE import SetPred4RE
def str2bool(v):
return v.lower() in ('true')
def add_argument_group(name):
arg = parser.add_argument_group(name)
return arg
def get_args():
args, unparsed = parser.parse_known_args()
if len(unparsed) > 1:
print("Unparsed args: {}".format(unparsed))
return args, unparsed
def set_seed(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
parser = argparse.ArgumentParser()
data_arg = add_argument_group('Data')
# data_arg.add_argument('--dataset_name', type=str, default="NYT-exact")
# data_arg.add_argument('--train_file', type=str, default="./data/NYT/exact_data/train.json")
# data_arg.add_argument('--valid_file', type=str, default="./data/NYT/exact_data/valid.json")
# data_arg.add_argument('--test_file', type=str, default="./data/NYT/exact_data/test.json")
# data_arg.add_argument('--dataset_name', type=str, default="NYT-partial")
# data_arg.add_argument('--train_file', type=str, default="./data/NYT/casrel_data/new_train.json")
# data_arg.add_argument('--valid_file', type=str, default="./data/NYT/casrel_data/new_valid.json")
# data_arg.add_argument('--test_file', type=str, default="./data/NYT/casrel_data/new_test.json")
data_arg.add_argument('--dataset_name', type=str, default="WebNLG")
data_arg.add_argument('--train_file', type=str, default="./data/WebNLG/clean_WebNLG/new_train.json")
data_arg.add_argument('--valid_file', type=str, default="./data/WebNLG/clean_WebNLG/new_valid.json")
data_arg.add_argument('--test_file', type=str, default="./data/WebNLG/clean_WebNLG/new_test.json")
data_arg.add_argument('--generated_data_directory', type=str, default="./data/generated_data/")
data_arg.add_argument('--generated_param_directory', type=str, default="./data/generated_data/model_param/")
data_arg.add_argument('--bert_directory', type=str, default="./bert_base_cased/")
data_arg.add_argument("--partial", type=str2bool, default=False)
learn_arg = add_argument_group('Learning')
learn_arg.add_argument('--model_name', type=str, default="Set-Prediction-Networks")
learn_arg.add_argument('--num_generated_triples', type=int, default=10)
learn_arg.add_argument('--num_decoder_layers', type=int, default=3)
learn_arg.add_argument('--matcher', type=str, default="avg", choices=['avg', 'min'])
learn_arg.add_argument('--na_rel_coef', type=float, default=1)
learn_arg.add_argument('--rel_loss_weight', type=float, default=1)
learn_arg.add_argument('--head_ent_loss_weight', type=float, default=2)
learn_arg.add_argument('--tail_ent_loss_weight', type=float, default=2)
learn_arg.add_argument('--fix_bert_embeddings', type=str2bool, default=True)
learn_arg.add_argument('--batch_size', type=int, default=8)
learn_arg.add_argument('--max_epoch', type=int, default=50)
learn_arg.add_argument('--gradient_accumulation_steps', type=int, default=1)
learn_arg.add_argument('--decoder_lr', type=float, default=2e-5)
learn_arg.add_argument('--encoder_lr', type=float, default=1e-5)
learn_arg.add_argument('--lr_decay', type=float, default=0.01)
learn_arg.add_argument('--weight_decay', type=float, default=1e-5)
learn_arg.add_argument('--max_grad_norm', type=float, default=0)
learn_arg.add_argument('--optimizer', type=str, default='AdamW', choices=['Adam', 'AdamW'])
evaluation_arg = add_argument_group('Evaluation')
evaluation_arg.add_argument('--n_best_size', type=int, default=100)
evaluation_arg.add_argument('--max_span_length', type=int, default=12) #NYT webNLG 10
misc_arg = add_argument_group('MISC')
misc_arg.add_argument('--refresh', type=str2bool, default=False)
misc_arg.add_argument('--use_gpu', type=str2bool, default=True)
misc_arg.add_argument('--visible_gpu', type=int, default=1)
misc_arg.add_argument('--random_seed', type=int, default=1)
args, unparsed = get_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.visible_gpu)
for arg in vars(args):
print(arg, ":", getattr(args, arg))
set_seed(args.random_seed)
data = build_data(args)
model = SetPred4RE(args, data.relational_alphabet.size())
trainer = Trainer(model, data, args)
trainer.train_model()