def load_sr_model(hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role,
                  fine_tune, device, ace_classifier=None, add_object=False, load_partial=False):
    """Build a situation-recognition (SR) model, optionally restoring fine-tuned weights."""
    if ace_classifier is None:
        ace_classifier = ACEClassifier(hps["wemb_dim"], hps["oc"], hps["ae_oc"], device)
    # SRModel_Object additionally consumes object-detection inputs.
    model_class = SRModel_Object if add_object else SRModel
    mymodel = model_class(hps, embeddingMatrix_noun, embeddingMatrix_verb,
                          embeddingMatrix_role, device, ace_classifier)
    if fine_tune is not None:
        mymodel.load_model(fine_tune, load_partial=load_partial)
    mymodel.to(device)
    return mymodel
def load_ee_model(hps, fine_tune, pretrained_embedding, device, ace_classifier=None):
    """Build an event-extraction (EE) model, optionally restoring fine-tuned weights."""
    assert pretrained_embedding is not None
    if ace_classifier is None:
        ace_classifier = ACEClassifier(2 * hps["lstm_dim"], hps["oc"], hps["ae_oc"], device)
    if fine_tune is None:
        return EDModel(hps, device, pretrained_embedding, ace_classifier)
    # ace_classifier.load_model(fine_tune_classifier)
    mymodel = EDModel(hps, device, pretrained_embedding, ace_classifier)
    mymodel.load_model(fine_tune)
    mymodel.to(device)
    return mymodel
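
# A minimal usage sketch for the two loaders above. The hyper-parameter values,
# file names, and `ee_hps` / `word_vectors` are hypothetical; the real values
# come from the argument parser and the saved hps files:
#
#   device = torch.device("cuda:0")
#   sr_hps = {"wemb_dim": 300, "oc": 34, "ae_oc": 20}              # assumed sizes
#   noun_emb = torch.FloatTensor(np.load("wnebd.npy")).to(device)  # assumed paths
#   verb_emb = torch.FloatTensor(np.load("wvebd.npy")).to(device)
#   role_emb = torch.FloatTensor(np.load("wrebd.npy")).to(device)
#
#   # one ACE classifier can be shared by both models, as the runners below do
#   ace_clf = ACEClassifier(sr_hps["wemb_dim"], sr_hps["oc"], sr_hps["ae_oc"], device)
#
#   # fresh SR model (no checkpoint); pass a path as `fine_tune` to restore one
#   sr_model = load_sr_model(sr_hps, noun_emb, verb_emb, role_emb,
#                            fine_tune=None, device=device, ace_classifier=ace_clf)
#
#   # EE model restored from a checkpoint (path is illustrative)
#   ee_model = load_ee_model(ee_hps, "ee_best.pt", word_vectors, device, ace_clf)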
def run(self): print("Running on", self.a.device) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True #################### loading event extraction dataset #################### if self.a.test_ee: log('testing event extraction corpus from %s' % self.a.test_ee) if self.a.test_ee: log('testing event extraction corpus from %s' % self.a.test_ee) # both for grounding and ee WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) # only for ee LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) EventsField = EventField(lower=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) if self.a.amr: colcc = 'simple-parsing' else: colcc = 'combined-parsing' print(colcc) train_ee_set = ACE2005Dataset(path=self.a.train_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=1) dev_ee_set = ACE2005Dataset(path=self.a.dev_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=0) # test_ee_set = ACE2005Dataset(path=self.a.test_ee, # fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "golden-event-mentions": ("LABEL", LabelField), # "all-events": ("EVENT", EventsField), # "all-entities": ("ENTITIES", EntitiesField)}, # amr=self.a.amr, keep_events=0) print('self.a.train_ee', self.a.train_ee) LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL) print('LabelField.vocab.stoi', LabelField.vocab.stoi) EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT) print('EventsField.vocab.stoi', EventsField.vocab.stoi) print('len(EventsField.vocab.itos)', len(EventsField.vocab.itos)) print('len(EventsField.vocab.stoi)', len(EventsField.vocab.stoi)) #################### loading SR dataset #################### # both for grounding and sr if self.a.train_sr: log('loading corpus from %s' % self.a.train_sr) transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True) vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True) vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True) # only need get_role_mask() and sr_mapping() train_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, 
vocab_role, vocab_verb, EventsField.vocab.stoi, LabelField.vocab.stoi, self.a.imsitu_ontology_file, self.a.train_sr, self.a.verb_mapping_file, None, None, 0, transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=False, filter_place=self.a.filter_place) #################### loading grounding dataset #################### if self.a.train_grounding: log('loading grounding corpus from %s' % self.a.train_grounding) # only for grounding IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True) train_grounding_set = GroundingDataset(path=self.a.train_grounding, img_dir=None, fields={"id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr) dev_grounding_set = GroundingDataset(path=self.a.dev_grounding, img_dir=None, fields={"id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr) # test_grounding_set = GroundingDataset(path=self.a.test_grounding, # img_dir=None, # fields={"id": ("IMAGEID", IMAGEIDField), # "sentence_id": ("SENTID", SENTIDField), # "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "all-entities": ("ENTITIES", EntitiesField), # # "image": ("IMAGE", IMAGEField), # }, # transform=transform, # amr=self.a.amr) #################### build vocabulary #################### if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS, vectors=pretrained_embedding) else: WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS) PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS, train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS) EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS, train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS) consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME] # print("O label is", consts.O_LABEL) consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME] # print("O label for AE is", consts.ROLE_O_LABEL) # dev_ee_set1 = ACE2005Dataset(path=self.a.dev_ee, # fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "golden-event-mentions": ("LABEL", LabelField), # "all-events": ("EVENT", EventsField), # "all-entities": ("ENTITIES", EntitiesField)}, # amr=self.a.amr, keep_events=1, only_keep=True) # # test_ee_set1 = ACE2005Dataset(path=self.a.test_ee, # fields={"sentence_id": ("SENTID", 
SENTIDField), "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "golden-event-mentions": ("LABEL", LabelField), # "all-events": ("EVENT", EventsField), # "all-entities": ("ENTITIES", EntitiesField)}, # amr=self.a.amr, keep_events=1, only_keep=True) # print("train set length", len(train_ee_set)) # # print("dev set length", len(dev_ee_set)) # print("dev set 1/1 length", len(dev_ee_set1)) # # print("test set length", len(test_ee_set)) # print("test set 1/1 length", len(test_ee_set1)) # sr model initialization if not self.a.sr_hps_path: self.a.sr_hps = eval(self.a.sr_hps) embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device) embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device) embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device) if "wvemb_size" not in self.a.sr_hps: self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word) if "wremb_size" not in self.a.sr_hps: self.a.sr_hps["wremb_size"] = len(vocab_role.id2word) if "wnemb_size" not in self.a.sr_hps: self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word) # if "ae_oc" not in self.a.sr_hps: # self.a.sr_hps["ae_oc"] = len(vocab_role.id2word) # self.a.ee_label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5 # self.a.ee_label_weight[consts.O_LABEL] = 1.0 # self.a.ee_arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5 # if not self.a.ee_hps_path: # self.a.ee_hps = eval(self.a.ee_hps) # if "wemb_size" not in self.a.ee_hps: # self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos) # if "pemb_size" not in self.a.ee_hps: # self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos) # if "psemb_size" not in self.a.ee_hps: # # self.a.ee_hps["psemb_size"] = max([train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2 # self.a.ee_hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest(), train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2 # if "eemb_size" not in self.a.ee_hps: # self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos) # if "oc" not in self.a.ee_hps: # self.a.ee_hps["oc"] = len(LabelField.vocab.itos) # if "ae_oc" not in self.a.ee_hps: # self.a.ee_hps["ae_oc"] = len(EventsField.vocab.itos) if "oc" not in self.a.sr_hps: self.a.sr_hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.sr_hps: self.a.sr_hps["ae_oc"] = len(EventsField.vocab.itos) ace_classifier = ACEClassifier(self.a.sr_hps["wemb_dim"], self.a.sr_hps["oc"], self.a.sr_hps["ae_oc"], self.device) ee_model = None # # if self.a.score_ee: # if self.a.finetune_ee: # log('init ee model from ' + self.a.finetune_ee) # ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device, ace_classifier) # log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads())) # else: # ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device, ace_classifier) # log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads())) # if self.a.score_sr: if self.a.finetune_sr: log('init sr model from ' + self.a.finetune_sr) sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device, ace_classifier, add_object=self.a.add_object) log('sr model loaded, there are %i 
sets of params' % len(sr_model.parameters_requires_grads())) else: sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device, ace_classifier, add_object=self.a.add_object) log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads())) model = GroundingModel(ee_model, sr_model, self.get_device()) # ee_model = torch.nn.DataParallel(ee_model) # sr_model = torch.nn.DataParallel(sr_model) # model = torch.nn.DataParallel(model) # if self.a.optimizer == "adadelta": # optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(), # weight_decay=self.a.l2decay) # elif self.a.optimizer == "adam": # optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(), # weight_decay=self.a.l2decay) # else: # optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(), # weight_decay=self.a.l2decay, # momentum=0.9) # log('optimizer in use: %s' % str(self.a.optimizer)) if not os.path.exists(self.a.out): os.mkdir(self.a.out) # with open(os.path.join(self.a.out, "word.vec"), "wb") as f: # pickle.dump(WordsField.vocab, f) # with open(os.path.join(self.a.out, "pos.vec"), "wb") as f: # pickle.dump(PosTagsField.vocab.stoi, f) # with open(os.path.join(self.a.out, "entity.vec"), "wb") as f: # pickle.dump(EntityLabelsField.vocab.stoi, f) with open(os.path.join(self.a.out, "label.vec"), "wb") as f: pickle.dump(LabelField.vocab.stoi, f) with open(os.path.join(self.a.out, "role.vec"), "wb") as f: pickle.dump(EventsField.vocab.stoi, f) log('init complete\n') # # ee mappings # self.a.ee_word_i2s = WordsField.vocab.itos # self.a.ee_label_i2s = LabelField.vocab.itos # self.a.ee_role_i2s = EventsField.vocab.itos # self.a.ee_role_mask = None # if self.a.apply_ee_role_mask: # self.a.ee_role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi, EventsField.vocab.stoi, self.device) # sr mappings self.a.sr_word_i2s = vocab_noun.id2word self.a.sr_label_i2s = vocab_verb.id2word # LabelField.vocab.itos self.a.sr_role_i2s = vocab_role.id2word self.a.role_masks = train_sr_set.get_role_mask().to_dense().to(self.device) writer = SummaryWriter(os.path.join(self.a.out, "exp")) self.a.writer = writer # loading testing data # voa_text = self.a.test_voa_text voa_image_dir = self.a.test_voa_image gt_voa_image = self.a.gt_voa_image gt_voa_text = self.a.gt_voa_text gt_voa_align =self.a.gt_voa_align sr_verb_mapping, sr_role_mapping = train_sr_set.get_sr_mapping() test_m2e2_set = M2E2Dataset(path=gt_voa_text, img_dir=voa_image_dir, fields={"image": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file, object_detection_threshold=self.a.object_detection_threshold, keep_events=self.a.keep_events, ) object_results, object_label, object_detection_threshold = test_m2e2_set.get_object_results() # build batch on cpu test_m2e2_iter = BucketIterator(test_m2e2_set, batch_size=1, train=False, shuffle=False, device=-1, 
sort_key=lambda x: len(x.POSTAGS)) # scores = 0.0 # now_bad = 0 # restart_used = 0 print("\nStarting testing...\n") # lr = parser.lr # optimizer = optimizer_constructor(lr=lr) # ee_tester = EDTester(LabelField.vocab.itos, EventsField.vocab.itos, self.a.ignore_time_test) # sr_tester = SRTester() # g_tester = GroundingTester() j_tester = JointTester(self.a.ignore_place_sr_test, self.a.ignore_time_test) # if self.a.visual_voa_ee_path is not None: # ee_visualizer = EDVisualizer(self.a.gt_voa_text) # else: # ee_visualizer = None image_gt = json.load(open(gt_voa_image)) # all_y = [] # all_y_ = [] # all_events = [] # all_events_ = [] vision_result = dict() # if self.a.visual_voa_g_path is not None and not os.path.exists(self.a.visual_voa_g_path): # os.makedirs(self.a.visual_voa_g_path, exist_ok=True) # if self.a.visual_voa_ee_path is not None and not os.path.exists(self.a.visual_voa_ee_path): # os.makedirs(self.a.visual_voa_ee_path, exist_ok=True) if self.a.visual_voa_sr_path is not None and not os.path.exists(self.a.visual_voa_sr_path): os.makedirs(self.a.visual_voa_sr_path, exist_ok=True) # grounding_writer = open(self.a.visual_voa_g_path, 'w') doc_done = set() with torch.no_grad(): model.eval() for batch in test_m2e2_iter: vision_result = joint_test_batch( model_g=model, batch_g=batch, device=self.device, transform=transform, img_dir=voa_image_dir, # ee_hyps=self.a.ee_hps, # ee_word_i2s=self.a.ee_word_i2s, # ee_label_i2s=self.a.ee_label_i2s, # ee_role_i2s=self.a.ee_role_i2s, # ee_tester=ee_tester, # ee_visualizer=ee_visualizer, sr_noun_i2s=self.a.sr_word_i2s, sr_verb_i2s=self.a.sr_label_i2s, sr_role_i2s=self.a.sr_role_i2s, # sr_tester=sr_tester, role_masks=self.a.role_masks, # ee_role_mask=self.a.ee_role_mask, # j_tester=j_tester, image_gt=image_gt, verb2type=sr_verb_mapping, role2role=sr_role_mapping, vision_result=vision_result, # all_y=all_y, # all_y_=all_y_, # all_events=all_events, # all_events_=all_events_, # visual_g_path=self.a.visual_voa_g_path, # visual_ee_path=self.a.visual_voa_ee_path, load_object=self.a.add_object, object_results=object_results, object_label=object_label, object_detection_threshold=object_detection_threshold, vocab_objlabel=vocab_noun.word2id, # apply_ee_role_mask=self.a.apply_ee_role_mask keep_events_sr=self.a.keep_events_sr, doc_done=doc_done, ) print('vision_result size', len(vision_result)) # pickle.dump(vision_result, open(os.path.join(self.a.out, 'vision_result.pkl'), 'w')) # ep, er, ef = ee_tester.calculate_report(all_y, all_y_, transform=True) # ap, ar, af = ee_tester.calculate_sets(all_events, all_events_) # if self.a.visual_voa_ee_path is not None: # ee_visualizer.rewrite_brat(self.a.visual_voa_ee_path, self.a.visual_voa_ee_gt_ann) # # print('text ep, er, ef', ep, er, ef) # print('text ap, ar, af', ap, ar, af) evt_p, evt_r, evt_f1, role_scores = j_tester.calculate_report( vision_result, voa_image_dir, self.a.visual_voa_sr_path, self.a.add_object, keep_events_sr=self.a.keep_events_sr )#consts.O_LABEL, consts.ROLE_O_LABEL) print('image event ep, er, ef \n', evt_p, '\n', evt_r, '\n', evt_f1) # if not self.a.add_object: # print('image att_iou ap, ar, af', role_scores['role_att_iou_p'], role_scores['role_att_iou_r'], # role_scores['role_att_iou_f1']) # print('image att_hit ap, ar, af', role_scores['role_att_hit_p'], role_scores['role_att_hit_r'], # role_scores['role_att_hit_f1']) # print('image att_cor ap, ar, af', role_scores['role_att_cor_p'], role_scores['role_att_cor_r'], # role_scores['role_att_cor_f1']) # else: # print('image obj_iou ap, ar, af', 
role_scores['role_obj_iou_p'], role_scores['role_obj_iou_r'], # role_scores['role_obj_iou_f1']) # print('image obj_iou_union ap, ar, af', role_scores['role_obj_iou_union_p'], role_scores['role_obj_iou_union_r'], # role_scores['role_obj_iou_union_f1']) for key in role_scores: print(key) for key_ in role_scores[key]: print(key_, role_scores[key][key_])
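
# The `transform` pipeline used by both runners is the standard ImageNet
# preprocessing (the Normalize constants are the ImageNet per-channel mean and
# std). A minimal sketch of applying it to one image -- the path is hypothetical:
#
#   from PIL import Image
#   img = Image.open("example.jpg").convert("RGB")
#   x = transform(img)        # float tensor of shape (3, 224, 224)
#   batch = x.unsqueeze(0)    # (1, 3, 224, 224), ready for a model forward pass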
def run(self): print("Running on", self.a.device) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True #################### loading event extraction dataset #################### if self.a.train_ee: log('loading event extraction corpus from %s' % self.a.train_ee) # both for grounding and ee WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) # only for ee LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) EventsField = EventField(lower=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) if self.a.amr: colcc = 'simple-parsing' else: colcc = 'combined-parsing' print(colcc) train_ee_set = ACE2005Dataset(path=self.a.train_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=1) dev_ee_set = ACE2005Dataset(path=self.a.dev_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=0) test_ee_set = ACE2005Dataset(path=self.a.test_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=0) if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL, vectors=pretrained_embedding) EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT, vectors=pretrained_embedding) else: LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL) EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT) # add role mask self.a.role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi, EventsField.vocab.stoi, self.device) #################### loading SR dataset #################### # both for grounding and sr if self.a.train_sr: log('loading corpus from %s' % self.a.train_sr) transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True) vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True) vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True) # train_sr_loader = 
imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, self.a.imsitu_ontology_file, # self.a.train_sr, self.a.verb_mapping_file, self.a.role_mapping_file, # self.a.object_class_map_file, self.a.object_detection_pkl_file, # self.a.object_detection_threshold, # transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1) #self.a.shuffle # dev_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, self.a.imsitu_ontology_file, # self.a.dev_sr, self.a.verb_mapping_file, self.a.role_mapping_file, # self.a.object_class_map_file, self.a.object_detection_pkl_file, # self.a.object_detection_threshold, # transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1) # test_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, self.a.imsitu_ontology_file, # self.a.test_sr, self.a.verb_mapping_file, self.a.role_mapping_file, # self.a.object_class_map_file, self.a.object_detection_pkl_file, # self.a.object_detection_threshold, # transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1) train_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb, LabelField.vocab.stoi, EventsField.vocab.stoi, self.a.imsitu_ontology_file, self.a.train_sr, self.a.verb_mapping_file, self.a.object_class_map_file, self.a.object_detection_pkl_file, self.a.object_detection_threshold, transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=self.a.add_object, filter_place=self.a.filter_place) dev_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb, LabelField.vocab.stoi, EventsField.vocab.stoi, self.a.imsitu_ontology_file, self.a.dev_sr, self.a.verb_mapping_file, self.a.object_class_map_file, self.a.object_detection_pkl_file, self.a.object_detection_threshold, transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=self.a.add_object, filter_place=self.a.filter_place) test_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb, LabelField.vocab.stoi, EventsField.vocab.stoi, self.a.imsitu_ontology_file, self.a.test_sr, self.a.verb_mapping_file, self.a.object_class_map_file, self.a.object_detection_pkl_file, self.a.object_detection_threshold, transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=self.a.add_object, filter_place=self.a.filter_place) #################### loading grounding dataset #################### if self.a.train_grounding: log('loading grounding corpus from %s' % self.a.train_grounding) # only for grounding IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True) train_grounding_set = GroundingDataset(path=self.a.train_grounding, img_dir=self.a.img_dir_grounding, fields={"id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file_g, object_detection_threshold=self.a.object_detection_threshold, ) dev_grounding_set = 
GroundingDataset(path=self.a.dev_grounding, img_dir=self.a.img_dir_grounding, fields={"id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file_g, object_detection_threshold=self.a.object_detection_threshold, ) test_grounding_set = GroundingDataset(path=self.a.test_grounding, img_dir=self.a.img_dir_grounding, fields={"id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file_g, object_detection_threshold=self.a.object_detection_threshold, ) #################### build vocabulary #################### if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS, vectors=pretrained_embedding) else: WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS) PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS, train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS) EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS, train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS) consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME] # print("O label is", consts.O_LABEL) consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME] # print("O label for AE is", consts.ROLE_O_LABEL) dev_ee_set1 = ACE2005Dataset(path=self.a.dev_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=1, only_keep=True) test_ee_set1 = ACE2005Dataset(path=self.a.test_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=1, only_keep=True) print("train set length", len(train_ee_set)) print("dev set length", len(dev_ee_set)) print("dev set 1/1 length", len(dev_ee_set1)) print("test set length", len(test_ee_set)) print("test set 1/1 length", len(test_ee_set1)) # sr model initialization if not self.a.sr_hps_path: self.a.sr_hps = eval(self.a.sr_hps) embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device) 
embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device) embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device) if "wvemb_size" not in self.a.sr_hps: self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word) if "wremb_size" not in self.a.sr_hps: self.a.sr_hps["wremb_size"] = len(vocab_role.id2word) if "wnemb_size" not in self.a.sr_hps: self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word) self.a.ee_label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5 self.a.ee_label_weight[consts.O_LABEL] = 1.0 self.a.ee_arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5 self.a.ee_hps = eval(self.a.ee_hps) if "wemb_size" not in self.a.ee_hps: self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos) if "pemb_size" not in self.a.ee_hps: self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos) if "psemb_size" not in self.a.ee_hps: # self.a.ee_hps["psemb_size"] = max([train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2 self.a.ee_hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest(), train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2 if "eemb_size" not in self.a.ee_hps: self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos) if "oc" not in self.a.ee_hps: self.a.ee_hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.ee_hps: self.a.ee_hps["ae_oc"] = len(EventsField.vocab.itos) if "oc" not in self.a.sr_hps: self.a.sr_hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.sr_hps: self.a.sr_hps["ae_oc"] = len(EventsField.vocab.itos) ee_tester = EDTester(LabelField.vocab.itos, EventsField.vocab.itos, self.a.ignore_time_test) sr_tester = SRTester() g_tester = GroundingTester() j_tester = JointTester(self.a.ignore_place_sr_test, self.a.ignore_time_test) ace_classifier = ACEClassifier(2 * self.a.ee_hps["lstm_dim"], self.a.ee_hps["oc"], self.a.ee_hps["ae_oc"], self.device) if self.a.finetune_ee: log('init ee model from ' + self.a.finetune_ee) ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device, ace_classifier) log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads())) else: ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device, ace_classifier) log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads())) if self.a.finetune_sr: log('init sr model from ' + self.a.finetune_sr) sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device, ace_classifier, add_object=self.a.add_object, load_partial=True) log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads())) else: sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device, ace_classifier, add_object=self.a.add_object, load_partial=True) log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads())) model = GroundingModel(ee_model, sr_model, self.get_device()) # ee_model = torch.nn.DataParallel(ee_model) # sr_model = torch.nn.DataParallel(sr_model) # model = torch.nn.DataParallel(model) if self.a.optimizer == "adadelta": optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) elif self.a.optimizer == "adam": 
optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) else: optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay, momentum=0.9) log('optimizer in use: %s' % str(self.a.optimizer)) if not os.path.exists(self.a.out): os.mkdir(self.a.out) with open(os.path.join(self.a.out, "word.vec"), "wb") as f: pickle.dump(WordsField.vocab, f) with open(os.path.join(self.a.out, "pos.vec"), "wb") as f: pickle.dump(PosTagsField.vocab.stoi, f) with open(os.path.join(self.a.out, "entity.vec"), "wb") as f: pickle.dump(EntityLabelsField.vocab.stoi, f) with open(os.path.join(self.a.out, "label.vec"), "wb") as f: pickle.dump(LabelField.vocab.stoi, f) with open(os.path.join(self.a.out, "role.vec"), "wb") as f: pickle.dump(EventsField.vocab.stoi, f) with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f: json.dump(self.a.ee_hps, f) with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f: json.dump(self.a.sr_hps, f) log('init complete\n') # ee mappings self.a.ee_word_i2s = WordsField.vocab.itos self.a.ee_label_i2s = LabelField.vocab.itos self.a.ee_role_i2s = EventsField.vocab.itos # sr mappings self.a.sr_word_i2s = vocab_noun.id2word self.a.sr_label_i2s = vocab_verb.id2word # LabelField.vocab.itos self.a.sr_role_i2s = vocab_role.id2word writer = SummaryWriter(os.path.join(self.a.out, "exp")) self.a.writer = writer joint_train( model_ee=ee_model, model_sr=sr_model, model_g=model, train_set_g=train_grounding_set, dev_set_g=dev_grounding_set, test_set_g=test_grounding_set, train_set_ee=train_ee_set, dev_set_ee=dev_ee_set, test_set_ee=test_ee_set, train_set_sr=train_sr_set, dev_set_sr=dev_sr_set, test_set_sr=test_sr_set, optimizer_constructor=optimizer_constructor, epochs=self.a.epochs, ee_tester=ee_tester, sr_tester=sr_tester, g_tester=g_tester, j_tester=j_tester, parser=self.a, other_testsets={ "dev ee 1/1": dev_ee_set1, "test ee 1/1": test_ee_set1, }, transform=transform, vocab_objlabel=vocab_noun.word2id ) log('Done!')
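
# Sketch of how the curried optimizer above is consumed inside the training
# loop (the names `parser` and `loss` are illustrative; only the learning rate
# is still missing from the partial):
#
#   optimizer = optimizer_constructor(lr=parser.lr)
#   ...
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()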